From d4074c5ae0f53c8b329b740ef87c72f3595d4b35 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 14:15:42 +0000 Subject: [PATCH 1/8] Added abstract precision plugin to expose clip_gradients function, use within accelerator to clip gradients --- pytorch_lightning/accelerators/accelerator.py | 25 ++++++++----------- pytorch_lightning/plugins/apex.py | 10 +++++++- pytorch_lightning/plugins/native_amp.py | 9 ++++++- pytorch_lightning/plugins/precision_plugin.py | 22 ++++++++++++++++ 4 files changed, 49 insertions(+), 17 deletions(-) create mode 100644 pytorch_lightning/plugins/precision_plugin.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3b762e08ed5e6..27b8ecde87ade 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -18,18 +18,13 @@ import torch -from pytorch_lightning.utilities import AMPType, rank_zero_warn +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict import torch.distributed as torch_distrib from pytorch_lightning import _logger as log -try: - from apex import amp -except ImportError: - amp = None - if torch.distributed.is_available(): from torch.distributed import ReduceOp else: @@ -37,7 +32,6 @@ class ReduceOp: SUM = None EPSILON = 1e-6 -EPSILON_FP16 = 1e-5 class Accelerator(object): @@ -149,17 +143,19 @@ def _clip_gradients(self, optimizer, clip_val=None): grad_clip_val = clip_val grad_clip_val = float(grad_clip_val) - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md if grad_clip_val <= 0: return - model = self.trainer.get_model() - if self.trainer.amp_backend == AMPType.APEX: - parameters = amp.master_params(optimizer) + if self.trainer.amp_backend: + self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer) else: - parameters = model.parameters() + self._clip_gradients_with_tpu_support(grad_clip_val) + def _clip_gradients_with_tpu_support(self, grad_clip_val): + # this code is a modification of torch.nn.utils.clip_grad_norm_ + # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md + model = self.trainer.get_model() + parameters = model.parameters() max_norm = grad_clip_val norm_type = float(2.0) @@ -176,8 +172,7 @@ def _clip_gradients(self, optimizer, clip_val=None): torch.norm(p.grad.data.to(device), norm_type, out=out[i]) total_norm = torch.norm(out, norm_type) - eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON) clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) for p in parameters: p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 0c8665e3719f3..b79decadea244 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -13,9 +13,11 @@ # limitations under the License. 
from typing import List, Tuple +import torch from torch.optim.optimizer import Optimizer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities import AMPType @@ -25,7 +27,7 @@ amp = None -class ApexPlugin: +class ApexPlugin(PrecisionPlugin): def __init__(self, trainer=None): self.trainer = trainer @@ -98,3 +100,9 @@ def configure_apex(self, amp, model, optimizers, amp_level): """ model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) return model, optimizers + + def clip_gradients(self, grad_clip_val, model, optimizer): + parameters = amp.master_params(optimizer) + max_norm = grad_clip_val + norm_type = float(2.0) + torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 98bc8dfc87d25..887ab431c1194 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -14,8 +14,10 @@ import torch +from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin -class NativeAMPPlugin: + +class NativeAMPPlugin(PrecisionPlugin): def __init__(self, trainer=None): """ @@ -51,3 +53,8 @@ def training_step(self, fx, args): with torch.cuda.amp.autocast(): output = fx(*args) return output + + def clip_gradients(self, grad_clip_val, model, optimizer): + max_norm = grad_clip_val + norm_type = float(2.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/precision_plugin.py new file mode 100644 index 0000000000000..1ad24140fceab --- /dev/null +++ b/pytorch_lightning/plugins/precision_plugin.py @@ -0,0 +1,22 @@ +import abc + + +class PrecisionPlugin(abc.ABC): + """ + Abstract class to extend for precision support (32/16 etc). + + This is extended to cover any specific logic required for precision support such as AMP/APEX or sharded + training. + """ + + def connect(self, model, optimizers): + raise NotImplementedError + + def training_step(self, fx, args): + raise NotImplementedError + + def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): + raise NotImplementedError + + def clip_gradients(self, grad_clip_val, model, optimizer): + raise NotImplementedError From 7a818f9100339ad23040ee8e4513ac6aedaf7d7c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 14:58:46 +0000 Subject: [PATCH 2/8] Exclude model from override, keep optimizer (needed for sharded clip gradients), add override for O2 support apex --- pytorch_lightning/plugins/apex.py | 34 +++++++++++++++++-- pytorch_lightning/plugins/native_amp.py | 3 +- pytorch_lightning/plugins/precision_plugin.py | 2 +- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index b79decadea244..a2d00dc06e6df 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import math from typing import List, Tuple import torch @@ -26,6 +27,8 @@ except ImportError: amp = None +FP16_EPSILON = 1e-5 + class ApexPlugin(PrecisionPlugin): @@ -101,8 +104,33 @@ def configure_apex(self, amp, model, optimizers, amp_level): model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) return model, optimizers - def clip_gradients(self, grad_clip_val, model, optimizer): - parameters = amp.master_params(optimizer) + def clip_gradients(self, grad_clip_val, optimizer): + """ + This code is a modification of torch.nn.utils.clip_grad_norm_ using a higher epsilon for fp16 weights. + This is important when setting amp_level to O2, and the master weights are in fp16. + Args: + grad_clip_val: Maximum norm of gradients. + optimizer: Optimizer with gradients that will be clipped. + """ + model = self.trainer.get_model() + parameters = model.parameters() max_norm = grad_clip_val norm_type = float(2.0) - torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type) + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == math.inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + clip_coef = max_norm / (total_norm + FP16_EPSILON) + if clip_coef < 1: + for p in parameters: + p.grad.detach().mul_(clip_coef.to(p.grad.device)) diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 887ab431c1194..7a0d342022bc0 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -54,7 +54,8 @@ def training_step(self, fx, args): output = fx(*args) return output - def clip_gradients(self, grad_clip_val, model, optimizer): + def clip_gradients(self, grad_clip_val, optimizer): + model = self.trainer.get_model() max_norm = grad_clip_val norm_type = float(2.0) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/precision_plugin.py index 1ad24140fceab..6352f612f2532 100644 --- a/pytorch_lightning/plugins/precision_plugin.py +++ b/pytorch_lightning/plugins/precision_plugin.py @@ -18,5 +18,5 @@ def training_step(self, fx, args): def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): raise NotImplementedError - def clip_gradients(self, grad_clip_val, model, optimizer): + def clip_gradients(self, grad_clip_val, optimizer): raise NotImplementedError From ac92cf35482d72cc073817465274b245962ce064 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 15:05:56 +0000 Subject: [PATCH 3/8] Fix doc --- pytorch_lightning/plugins/apex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index a2d00dc06e6df..98fd5a54c1de5 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -106,7 +106,7 @@ def configure_apex(self, amp, model, optimizers, amp_level): def clip_gradients(self, grad_clip_val, optimizer): """ - This code is a modification of torch.nn.utils.clip_grad_norm_ using a higher epsilon for fp16 weights. 
+ This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights. This is important when setting amp_level to O2, and the master weights are in fp16. Args: grad_clip_val: Maximum norm of gradients. From d3e6e9e718d851153f3439fee0d72202ecaa9098 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 15:20:17 +0000 Subject: [PATCH 4/8] Applied codereview changes --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/plugins/apex.py | 2 +- pytorch_lightning/plugins/native_amp.py | 4 +--- pytorch_lightning/plugins/precision_plugin.py | 13 +++++++++++++ 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 27b8ecde87ade..7e2dfb43ffc83 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -157,7 +157,7 @@ def _clip_gradients_with_tpu_support(self, grad_clip_val): model = self.trainer.get_model() parameters = model.parameters() max_norm = grad_clip_val - norm_type = float(2.0) + norm_type = 2.0 if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 98fd5a54c1de5..6b94b8c4d3834 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -115,7 +115,7 @@ def clip_gradients(self, grad_clip_val, optimizer): model = self.trainer.get_model() parameters = model.parameters() max_norm = grad_clip_val - norm_type = float(2.0) + norm_type = 2.0 if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 7a0d342022bc0..c566ab75f3f01 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -56,6 +56,4 @@ def training_step(self, fx, args): def clip_gradients(self, grad_clip_val, optimizer): model = self.trainer.get_model() - max_norm = grad_clip_val - norm_type = float(2.0) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=norm_type) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0) diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/precision_plugin.py index 6352f612f2532..25a242c9993dd 100644 --- a/pytorch_lightning/plugins/precision_plugin.py +++ b/pytorch_lightning/plugins/precision_plugin.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc From 90460d6dbb9b3d2c3cf23f6db5a0e0005a1564f1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 15:25:33 +0000 Subject: [PATCH 5/8] Refactored clip function to encapsulate tpu changes with tpu accelerator. 
Default to standard clip function for vanilla torch --- pytorch_lightning/accelerators/accelerator.py | 38 ++----------------- .../accelerators/tpu_accelerator.py | 32 ++++++++++++++-- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7e2dfb43ffc83..04cf8359aa8d5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import math from enum import Enum from typing import Any, Optional, Union @@ -31,8 +30,6 @@ class ReduceOp: SUM = None -EPSILON = 1e-6 - class Accelerator(object): @@ -133,10 +130,6 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) def clip_gradients(self, optimizer, clip_val=None): - # TODO: separate TPU case from here - self._clip_gradients(optimizer, clip_val) - - def _clip_gradients(self, optimizer, clip_val=None): # use the trainer's clip val if none passed grad_clip_val = self.trainer.gradient_clip_val if clip_val is not None: @@ -145,37 +138,14 @@ def _clip_gradients(self, optimizer, clip_val=None): if grad_clip_val <= 0: return + self._clip_gradients(optimizer, clip_val) + def _clip_gradients(self, optimizer, grad_clip_val): if self.trainer.amp_backend: self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer) else: - self._clip_gradients_with_tpu_support(grad_clip_val) - - def _clip_gradients_with_tpu_support(self, grad_clip_val): - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md - model = self.trainer.get_model() - parameters = model.parameters() - max_norm = grad_clip_val - norm_type = 2.0 - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - if norm_type == math.inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - else: - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON) - clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) - for p in parameters: - p.grad.data.mul_(clip_coef.to(p.grad.data.device)) + model = self.trainer.get_model() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0) def on_train_epoch_end(self, outputs): pass diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 15386b133f8bd..44ed3b361e5d4 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import io +import math import os import re from typing import Optional, Union, Any @@ -35,6 +36,8 @@ import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.distributed.xla_multiprocessing as xmp +EPSILON = 1e-6 + class TPUAccelerator(Accelerator): @@ -261,10 +264,31 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure): using_lbfgs=is_lbfgs ) - def clip_gradients(self, optimizer, clip_val=None): - # apply clip gradients - # TODO: separate TPU case from here - self._clip_gradients(optimizer, clip_val) + def _clip_gradients(self, optimizer, grad_clip_val): + # this code is a modification of torch.nn.utils.clip_grad_norm_ + # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md + model = self.trainer.get_model() + parameters = model.parameters() + max_norm = grad_clip_val + norm_type = 2.0 + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + if norm_type == math.inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + else: + device = parameters[0].device + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) + + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON) + clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) + for p in parameters: + p.grad.data.mul_(clip_coef.to(p.grad.data.device)) def barrier(self, name: Optional[str] = None): torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}") From 9d3b82cf0187fc8cb569e3a5ec2efe2a244af5c5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 15:36:07 +0000 Subject: [PATCH 6/8] Pass correct grad clip val --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 04cf8359aa8d5..018ce548dee2d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -138,7 +138,7 @@ def clip_gradients(self, optimizer, clip_val=None): if grad_clip_val <= 0: return - self._clip_gradients(optimizer, clip_val) + self._clip_gradients(optimizer, grad_clip_val) def _clip_gradients(self, optimizer, grad_clip_val): if self.trainer.amp_backend: From 205967aeaf492cb27219bd4fea2373b43e55aa4c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 15:51:34 +0000 Subject: [PATCH 7/8] Moved var to property --- pytorch_lightning/accelerators/tpu_accelerator.py | 8 +++++--- pytorch_lightning/plugins/apex.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 44ed3b361e5d4..f8b3aea4b8582 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -36,8 +36,6 @@ import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.distributed.xla_multiprocessing as xmp -EPSILON = 1e-6 - class TPUAccelerator(Accelerator): @@ -285,7 +283,7 @@ def _clip_gradients(self, optimizer, grad_clip_val): torch.norm(p.grad.data.to(device), norm_type, out=out[i]) total_norm = torch.norm(out, norm_type) - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON) + clip_coef = torch.tensor(max_norm, device=device) / 
(total_norm + self.norm_clipping_epsilon) clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) for p in parameters: p.grad.data.mul_(clip_coef.to(p.grad.data.device)) @@ -367,3 +365,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def norm_clipping_epsilon(self): + return 1e-6 diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 6b94b8c4d3834..53fd7ee156e53 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -27,8 +27,6 @@ except ImportError: amp = None -FP16_EPSILON = 1e-5 - class ApexPlugin(PrecisionPlugin): @@ -130,7 +128,11 @@ def clip_gradients(self, grad_clip_val, optimizer): else: total_norm = torch.norm( torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) - clip_coef = max_norm / (total_norm + FP16_EPSILON) + clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon) if clip_coef < 1: for p in parameters: p.grad.detach().mul_(clip_coef.to(p.grad.device)) + + @property + def norm_clipping_epsilon(self): + return 1e-5 From 58190e408b2d63d29ce1c33c89ab6ddc7580046d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Nov 2020 16:47:20 +0000 Subject: [PATCH 8/8] Apply code review suggestions --- pytorch_lightning/accelerators/accelerator.py | 9 +++++---- .../accelerators/tpu_accelerator.py | 17 +++++++---------- pytorch_lightning/plugins/apex.py | 19 ++++++++----------- pytorch_lightning/plugins/native_amp.py | 6 ++++-- pytorch_lightning/plugins/precision_plugin.py | 5 ++++- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 018ce548dee2d..a0d8f6f21a2f7 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -16,6 +16,7 @@ from typing import Any, Optional, Union import torch +from torch.optim import Optimizer from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import move_data_to_device @@ -140,12 +141,12 @@ def clip_gradients(self, optimizer, clip_val=None): return self._clip_gradients(optimizer, grad_clip_val) - def _clip_gradients(self, optimizer, grad_clip_val): + def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): if self.trainer.amp_backend: - self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer) + self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type) else: model = self.trainer.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) def on_train_epoch_end(self, outputs): pass @@ -166,7 +167,7 @@ def setup_optimizers(self, model): self.trainer.optimizer_frequencies = optimizer_frequencies def init_ddp_connection( - self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True + self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True ) -> None: os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index f8b3aea4b8582..54ee57b74a16a 
100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -19,6 +19,7 @@ import torch import torch.multiprocessing as mp +from torch.optim import Optimizer from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp @@ -262,26 +263,22 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure): using_lbfgs=is_lbfgs ) - def _clip_gradients(self, optimizer, grad_clip_val): + def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): # this code is a modification of torch.nn.utils.clip_grad_norm_ # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md model = self.trainer.get_model() parameters = model.parameters() max_norm = grad_clip_val - norm_type = 2.0 if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) - if norm_type == math.inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - else: - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) + device = parameters[0].device + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon) clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 53fd7ee156e53..654f7202fb9d1 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Tuple +from typing import List, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -102,32 +102,29 @@ def configure_apex(self, amp, model, optimizers, amp_level): model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) return model, optimizers - def clip_gradients(self, grad_clip_val, optimizer): + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): """ This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights. This is important when setting amp_level to O2, and the master weights are in fp16. Args: grad_clip_val: Maximum norm of gradients. optimizer: Optimizer with gradients that will be clipped. + norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. """ model = self.trainer.get_model() parameters = model.parameters() - max_norm = grad_clip_val - norm_type = 2.0 + max_norm = float(grad_clip_val) if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] - max_norm = float(max_norm) - norm_type = float(norm_type) + if len(parameters) == 0: return torch.tensor(0.) 
device = parameters[0].grad.device - if norm_type == math.inf: - total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) - else: - total_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon) if clip_coef < 1: for p in parameters: diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index c566ab75f3f01..1a6649986132c 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch +from torch.optim import Optimizer from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin @@ -54,6 +56,6 @@ def training_step(self, fx, args): output = fx(*args) return output - def clip_gradients(self, grad_clip_val, optimizer): + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): model = self.trainer.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/precision_plugin.py index 25a242c9993dd..0102f677391ff 100644 --- a/pytorch_lightning/plugins/precision_plugin.py +++ b/pytorch_lightning/plugins/precision_plugin.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +from typing import Union + +from torch.optim import Optimizer class PrecisionPlugin(abc.ABC): @@ -31,5 +34,5 @@ def training_step(self, fx, args): def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): raise NotImplementedError - def clip_gradients(self, grad_clip_val, optimizer): + def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): raise NotImplementedError
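
For context, the hook that the series converges on in [PATCH 8/8] is clip_gradients(grad_clip_val, optimizer, norm_type), invoked by Accelerator._clip_gradients through trainer.precision_connector.backend. Below is a minimal stand-alone sketch of a plugin honouring that call shape; ToyPrecisionPlugin, the mirrored base class, and the driver code are illustrative assumptions for demonstration only, not part of these patches or of the PyTorch Lightning API.

# Illustrative sketch: a stand-alone plugin matching the final
# clip_gradients(grad_clip_val, optimizer, norm_type) signature from PATCH 8/8.
# ToyPrecisionPlugin and the driver below are assumptions, not Lightning code.
import abc
from typing import Union

import torch
from torch.optim import Optimizer


class PrecisionPluginBase(abc.ABC):
    # Mirrors the abstract hook declared in pytorch_lightning/plugins/precision_plugin.py
    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
        raise NotImplementedError


class ToyPrecisionPlugin(PrecisionPluginBase):
    """Clips gradients of the parameters owned by the optimizer."""

    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
        # Collect parameters from the optimizer's param groups instead of a model reference.
        parameters = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=grad_clip_val, norm_type=norm_type)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    plugin = ToyPrecisionPlugin()
    # Same call shape the accelerator uses after the refactor:
    # backend.clip_gradients(grad_clip_val, optimizer, norm_type)
    plugin.clip_gradients(0.5, optimizer, norm_type=2.0)
    optimizer.step()

As the commit messages note, the model argument was dropped from the hook (each plugin fetches it from the trainer itself) while the optimizer stays in the signature because sharded gradient clipping needs it; the sketch above simply reads the parameters off the optimizer to stay self-contained.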