Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sharded Accelerator 1/n: Expose clip gradients to plugins via abstract class #4639

Merged
merged 10 commits into from
Nov 12, 2020
25 changes: 10 additions & 15 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,20 @@

import torch

from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

try:
from apex import amp
except ImportError:
amp = None

if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None

EPSILON = 1e-6
EPSILON_FP16 = 1e-5


class Accelerator(object):
Expand Down Expand Up @@ -149,17 +143,19 @@ def _clip_gradients(self, optimizer, clip_val=None):
grad_clip_val = clip_val
grad_clip_val = float(grad_clip_val)

# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if grad_clip_val <= 0:
return

model = self.trainer.get_model()
if self.trainer.amp_backend == AMPType.APEX:
parameters = amp.master_params(optimizer)
if self.trainer.amp_backend:
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
else:
parameters = model.parameters()
self._clip_gradients_with_tpu_support(grad_clip_val)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

def _clip_gradients_with_tpu_support(self, grad_clip_val):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val
norm_type = float(2.0)

Expand All @@ -176,8 +172,7 @@ def _clip_gradients(self, optimizer, clip_val=None):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))
Expand Down
10 changes: 9 additions & 1 deletion pytorch_lightning/plugins/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
from typing import List, Tuple

import torch
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities import AMPType

Expand All @@ -25,7 +27,7 @@
amp = None


class ApexPlugin:
class ApexPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
self.trainer = trainer
Expand Down Expand Up @@ -98,3 +100,9 @@ def configure_apex(self, amp, model, optimizers, amp_level):
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

def clip_gradients(self, grad_clip_val, model, optimizer):
    """
    Clip the gradients of the Apex master parameters by their total 2-norm.

    Args:
        grad_clip_val: maximum allowed norm, passed as ``max_norm``.
        model: not used by this implementation; present so the signature
            matches ``PrecisionPlugin.clip_gradients``.
        optimizer: optimizer whose parameters are obtained through
            ``amp.master_params``.
    """
    torch.nn.utils.clip_grad_norm_(
        amp.master_params(optimizer),
        max_norm=grad_clip_val,
        norm_type=2.0,
    )
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import torch

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin

class NativeAMPPlugin:

class NativeAMPPlugin(PrecisionPlugin):

def __init__(self, trainer=None):
"""
Expand Down Expand Up @@ -51,3 +53,8 @@ def training_step(self, fx, args):
with torch.cuda.amp.autocast():
output = fx(*args)
return output

def clip_gradients(self, grad_clip_val, model, optimizer):
    """
    Clip the gradients of ``model.parameters()`` by their total 2-norm.

    Args:
        grad_clip_val: maximum allowed norm, passed as ``max_norm``.
        model: module whose parameter gradients are clipped in place.
        optimizer: not used by this implementation; present so the
            signature matches ``PrecisionPlugin.clip_gradients``.
    """
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions pytorch_lightning/plugins/precision_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import abc
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved


class PrecisionPlugin(abc.ABC):
    """
    Base class to extend for precision support (32/16 etc).

    Subclasses cover any specific logic required for precision support,
    such as AMP/APEX or sharded training. Every hook below raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def connect(self, model, optimizers):
        """Hook taking the model and its optimizers; must be overridden."""
        raise NotImplementedError()

    def training_step(self, fx, args):
        """Hook taking a callable ``fx`` and its ``args``; must be overridden."""
        raise NotImplementedError()

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        """Hook taking the closure loss, optimizer and its index; must be overridden."""
        raise NotImplementedError()

    def clip_gradients(self, grad_clip_val, model, optimizer):
        """Hook taking the clip value, model and optimizer; must be overridden."""
        raise NotImplementedError()