Skip to content

Commit

Permalink
Sharded Accelerator 1/n: Expose clip gradients to plugins via abstrac…
Browse files Browse the repository at this point in the history
…t class (#4639)

* Added abstract precision plugin to expose clip_gradients function, use within accelerator to clip gradients

* Exclude model from override, keep optimizer (needed for sharded clip gradients), add override for O2 support apex

* Fix doc

* Applied codereview changes

* Refactored clip function to encapsulate tpu changes with tpu accelerator. Default to standard clip function for vanilla torch

* Pass correct grad clip val

* Moved var to property

* Apply code review suggestions
  • Loading branch information
SeanNaren authored and rohitgr7 committed Nov 21, 2020
1 parent e6857de commit d0b8d10
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 50 deletions.
52 changes: 9 additions & 43 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
from enum import Enum
from typing import Any, Optional, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

try:
from apex import amp
except ImportError:
amp = None

if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None

EPSILON = 1e-6
EPSILON_FP16 = 1e-5


class Accelerator(object):

Expand Down Expand Up @@ -139,48 +131,22 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer, clip_val=None):
# TODO: separate TPU case from here
self._clip_gradients(optimizer, clip_val)

def _clip_gradients(self, optimizer, clip_val=None):
# use the trainer's clip val if none passed
grad_clip_val = self.trainer.gradient_clip_val
if clip_val is not None:
grad_clip_val = clip_val
grad_clip_val = float(grad_clip_val)

# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if grad_clip_val <= 0:
return
self._clip_gradients(optimizer, grad_clip_val)

model = self.trainer.get_model()
if self.trainer.amp_backend == AMPType.APEX:
parameters = amp.master_params(optimizer)
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
if self.trainer.amp_backend:
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type)
else:
parameters = model.parameters()

max_norm = grad_clip_val
norm_type = float(2.0)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)

def on_train_epoch_end(self, outputs):
pass
Expand All @@ -201,7 +167,7 @@ def setup_optimizers(self, model):
self.trainer.optimizer_frequencies = optimizer_frequencies

def init_ddp_connection(
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
Expand Down
31 changes: 27 additions & 4 deletions pytorch_lightning/accelerators/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import math
import os
import re
from typing import Optional, Union, Any

import torch
import torch.multiprocessing as mp
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
Expand Down Expand Up @@ -261,10 +263,27 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
using_lbfgs=is_lbfgs
)

def clip_gradients(self, optimizer, clip_val=None):
# apply clip gradients
# TODO: separate TPU case from here
self._clip_gradients(optimizer, clip_val)
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))

def barrier(self, name: Optional[str] = None):
torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
Expand Down Expand Up @@ -343,3 +362,7 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return tensor

@property
def norm_clipping_epsilon(self):
return 1e-6
39 changes: 37 additions & 2 deletions pytorch_lightning/plugins/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import math
from typing import List, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities import AMPType

Expand All @@ -25,7 +28,7 @@
amp = None


class ApexPlugin:
class ApexPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
self.trainer = trainer
Expand Down Expand Up @@ -98,3 +101,35 @@ def configure_apex(self, amp, model, optimizers, amp_level):
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
"""
This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
This is important when setting amp_level to O2, and the master weights are in fp16.
Args:
grad_clip_val: Maximum norm of gradients.
optimizer: Optimizer with gradients that will be clipped.
norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
"""
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = float(grad_clip_val)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]

if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef.to(p.grad.device))

@property
def norm_clipping_epsilon(self):
return 1e-5
10 changes: 9 additions & 1 deletion pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin

class NativeAMPPlugin:

class NativeAMPPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
"""
Expand Down Expand Up @@ -51,3 +55,7 @@ def training_step(self, fx, args):
with torch.cuda.amp.autocast():
output = fx(*args)
return output

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
38 changes: 38 additions & 0 deletions pytorch_lightning/plugins/precision_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union

from torch.optim import Optimizer


class PrecisionPlugin(abc.ABC):
"""
Abstract class to extend for precision support (32/16 etc).
This is extended to cover any specific logic required for precision support such as AMP/APEX or sharded
training.
"""

def connect(self, model, optimizers):
raise NotImplementedError

def training_step(self, fx, args):
raise NotImplementedError

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
raise NotImplementedError

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
raise NotImplementedError

0 comments on commit d0b8d10

Please sign in to comment.