Skip to content

Commit

Permalink
Apply code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNaren committed Nov 12, 2020
1 parent 378c4eb commit 58190e4
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 28 deletions.
9 changes: 5 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Optional, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
Expand Down Expand Up @@ -140,12 +141,12 @@ def clip_gradients(self, optimizer, clip_val=None):
return
self._clip_gradients(optimizer, grad_clip_val)

def _clip_gradients(self, optimizer, grad_clip_val):
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
if self.trainer.amp_backend:
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type)
else:
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)

def on_train_epoch_end(self, outputs):
pass
Expand All @@ -166,7 +167,7 @@ def setup_optimizers(self, model):
self.trainer.optimizer_frequencies = optimizer_frequencies

def init_ddp_connection(
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
Expand Down
17 changes: 7 additions & 10 deletions pytorch_lightning/accelerators/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch
import torch.multiprocessing as mp
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
Expand Down Expand Up @@ -262,26 +263,22 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
using_lbfgs=is_lbfgs
)

def _clip_gradients(self, optimizer, grad_clip_val):
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val
norm_type = 2.0

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)
device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
Expand Down
19 changes: 8 additions & 11 deletions pytorch_lightning/plugins/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Tuple
from typing import List, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -102,32 +102,29 @@ def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

def clip_gradients(self, grad_clip_val, optimizer):
def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
"""
This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
This is important when setting amp_level to O2, and the master weights are in fp16.
Args:
grad_clip_val: Maximum norm of gradients.
optimizer: Optimizer with gradients that will be clipped.
norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
"""
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val
norm_type = 2.0
max_norm = float(grad_clip_val)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)

if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == math.inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
if clip_coef < 1:
for p in parameters:
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin

Expand Down Expand Up @@ -54,6 +56,6 @@ def training_step(self, fx, args):
output = fx(*args)
return output

def clip_gradients(self, grad_clip_val, optimizer):
def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union

from torch.optim import Optimizer


class PrecisionPlugin(abc.ABC):
Expand All @@ -31,5 +34,5 @@ def training_step(self, fx, args):
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
raise NotImplementedError

def clip_gradients(self, grad_clip_val, optimizer):
def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
raise NotImplementedError

0 comments on commit 58190e4

Please sign in to comment.