Skip to content

Commit

Permalink
Fast Gradient and Ghost Clipping (pytorch#656)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#656

Itroducing Fast Gradient Clipping and Ghost Clipping to Opacus for memory-efficient training with DP SGD.

Reviewed By: HuanyuZhang

Differential Revision: D58210796
  • Loading branch information
EnayatUllah authored and facebook-github-bot committed Jul 21, 2024
1 parent 1235e1e commit dc162b4
Show file tree
Hide file tree
Showing 10 changed files with 856 additions and 13 deletions.
11 changes: 8 additions & 3 deletions opacus/grad_sample/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,28 @@
from .dp_multihead_attention import compute_sequence_bias_grad_sample # noqa
from .dp_rnn import compute_rnn_linear_grad_sample # noqa
from .embedding import compute_embedding_grad_sample # noqa
from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
from .grad_sample_module import (GradSampleModule,
create_or_accumulate_grad_sample)
from .grad_sample_module_fast_gradient_clipping import \
GradSampleModuleFastGradientClipping # noqa
from .group_norm import compute_group_norm_grad_sample # noqa
from .gsm_base import AbstractGradSampleModule
from .gsm_exp_weights import GradSampleModuleExpandedWeights
from .gsm_no_op import GradSampleModuleNoOp
from .instance_norm import compute_instance_norm_grad_sample # noqa
from .layer_norm import compute_layer_norm_grad_sample # noqa
from .linear import compute_linear_grad_sample # noqa
from .utils import get_gsm_class, register_grad_sampler, wrap_model

from .utils import (get_gsm_class, register_grad_sampler,
register_norm_sampler, wrap_model)

__all__ = [
"GradSampleModule",
"GradSampleModuleFastGradientClipping",
"GradSampleModuleExpandedWeights",
"GradSampleModuleNoOp",
"AbstractGradSampleModule",
"register_grad_sampler",
"register_norm_sampler",
"create_or_accumulate_grad_sample",
"wrap_model",
"get_gsm_class",
Expand Down
5 changes: 2 additions & 3 deletions opacus/grad_sample/grad_sample_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@


logger = logging.getLogger(__name__)
logger.disabled = True


def create_or_accumulate_grad_sample(
Expand Down Expand Up @@ -465,10 +466,8 @@ def validate(
errors.extend(
[
NotImplementedError(
f"Model contains a trainable layer "
f"Model contains a trainable layer with buffers"
f"that Opacus doesn't currently support({m_name}:{m}). "
f"Please implement and register grad sampler for this layer. "
f"(See opacus.grad_sample.utils.register_grad_sampler)"
)
for m_name, m in trainable_modules(module)
# With functorch, all modules are trainable
Expand Down
222 changes: 222 additions & 0 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import List

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_module import (
GradSampleModule,
create_or_accumulate_grad_sample,
promote_current_grad_sample,
)
from opacus.utils.module_utils import requires_grad, trainable_parameters


logger = logging.getLogger(__name__)
logger.disabled = True


def create_norm_sample(
*, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
"""
Creates a ``_norm_sample`` attribute in the given parameter
Args:
param: Parameter to which ``_norm_sample`` will be added
grad_sample: Per-sample gradients tensor. Must be of the same
shape as ``param`` with extra batch dimension
"""

if param.requires_grad:
param._norm_sample = torch.zeros(
torch.Size([max_batch_len, 1]),
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)


class GradSampleModuleFastGradientClipping(GradSampleModule):
"""
Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping
Computes norms of gradients without gradient instantiation
"""

NORM_SAMPLERS = {}

def __init__(
self,
m: nn.Module,
*,
batch_first=True,
loss_reduction="mean",
strict: bool = True,
force_functorch=False,
max_grad_norm=1,
use_ghost_clipping=True,
):
"""
Args:
m: nn.Module to be wrapped
batch_first: Flag to indicate if the input tensor to the corresponding module
has the first dimension representing the batch. If set to True, dimensions on
input tensor are expected be ``[batch_size, ...]``, otherwise
``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
max_grad_norm: The value at which gradients are to be clipped.
strict: If set to True, the input module will be validated to make sure that
it does not have buffers in all its submodules.
force_functorch: If set to ``True``, will use functorch to compute
all per sample gradients. Otherwise, functorch will be used only
for layers without registered grad sampler methods.
use_ghost_clipping: If set to ``True``, Ghost Clipping
will be used for clipping gradients of supported layers. If ``False``, Fast
Gradient Clipping will be used for all layers.
Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) doesn't have a registered grad sampler function.
"""

super().__init__(
m,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping

def get_coeff(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
norm_sample = self.get_norm_sample()
return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

def get_norm_sample(self) -> torch.Tensor:
"""Get per-example gradient norms."""
norm_sample = torch.stack(
[param._norm_sample for param in self.trainable_parameters], dim=0
).norm(2, dim=0)
return norm_sample

def capture_activations_hook(
self,
module: nn.Module,
forward_input: List[torch.Tensor],
_forward_output: torch.Tensor,
):
if (
not requires_grad(module)
or not module.training
or not torch.is_grad_enabled()
or not self.hooks_enabled
):
return

if not hasattr(module, "activations"):
module.activations = []
module.activations.append([t.detach() for t in forward_input]) # pyre-ignore

for _, p in trainable_parameters(module):
p._forward_counter += 1
if (
self.use_ghost_clipping
and p._forward_counter > 1
and type(module) in self.NORM_SAMPLERS
):
raise NotImplementedError(
"Parameter tying is not supported with Ghost Clipping"
)

def capture_backprops_hook(
self,
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
):
"""
Computes norms of per sample gradient given the current backprops and activations
stored by the associated forward hook. Computed per sample gradient norms are
stored in ``norm_sample`` field in each parameter.
Args:
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
"""
if not self.hooks_enabled:
return

backprops = forward_output[0].detach()
activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
loss_reduction=loss_reduction,
batch_first=batch_first,
)

if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS:
norm_sampler_fn = self.NORM_SAMPLERS[type(module)]
norm_samples = norm_sampler_fn(module, activations, backprops)

for param, ns in norm_samples.items():
if param.requires_grad:
param._norm_sample = ns
param._forward_counter -= 1

else:
if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
else:
grad_sampler_fn = ft_compute_per_sample_gradient

grad_samples = grad_sampler_fn(module, activations, backprops)
for param, gs in grad_samples.items():
create_or_accumulate_grad_sample(
param=param, grad_sample=gs, max_batch_len=module.max_batch_len
)
del grad_samples
# Detect end of current batch processing and switch accumulation
# mode from sum to stacking. Used for RNNs and tied parameters
# (See #417 for details)
for _, p in trainable_parameters(module):
p._forward_counter -= 1
if p._forward_counter == 0:
promote_current_grad_sample(p)
create_norm_sample(
param=p,
grad_sample=p.grad_sample,
max_batch_len=module.max_batch_len,
)
del p.grad_sample

if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len
46 changes: 45 additions & 1 deletion opacus/grad_sample/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List

import torch
import torch.nn as nn
from opt_einsum import contract

from .utils import register_grad_sampler
from .utils import register_grad_sampler, register_norm_sampler


logger = logging.getLogger(__name__)
logging.disabled = False


@register_grad_sampler(nn.Linear)
Expand All @@ -42,3 +47,42 @@ def compute_linear_grad_sample(
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("n...k->nk", backprops)
return ret


@register_norm_sampler(nn.Linear)
def compute_linear_norm_sample(
layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
"""
Computes per sample gradient norms for ``nn.Linear`` layer
Args:
layer: Layer
activations: Activations
backprops: Backpropagations
"""
activations = activations[0]
ret = {}

if backprops.dim() == 2:
if layer.weight.requires_grad:
g = contract("n...i,n...i->n", backprops, backprops)
a = contract("n...j,n...j->n", activations, activations)
ret[layer.weight] = torch.sqrt((g * a).flatten())
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.sqrt(
contract("n...i,n...i->n", backprops, backprops).flatten()
)
elif backprops.dim() == 3:
if layer.weight.requires_grad:

ggT = contract("nik,njk->nij", backprops, backprops) # batchwise g g^T
aaT = contract("nik,njk->nij", activations, activations) # batchwise a a^T
ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0)

ret[layer.weight] = torch.sqrt(ga)
if layer.bias is not None and layer.bias.requires_grad:
ggT = contract("nik,njk->nij", backprops, backprops)
gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0)
ret[layer.bias] = torch.sqrt(gg)
return ret
34 changes: 33 additions & 1 deletion opacus/grad_sample/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
# !/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,6 +18,9 @@
import torch.nn as nn

from .grad_sample_module import GradSampleModule
from .grad_sample_module_fast_gradient_clipping import (
GradSampleModuleFastGradientClipping,
)
from .gsm_base import AbstractGradSampleModule
from .gsm_exp_weights import GradSampleModuleExpandedWeights
from .gsm_no_op import GradSampleModuleNoOp
Expand Down Expand Up @@ -46,6 +49,33 @@ def decorator(f):
)
for target_class in target_classes:
GradSampleModule.GRAD_SAMPLERS[target_class] = f
GradSampleModuleFastGradientClipping.GRAD_SAMPLERS[target_class] = f
return f

return decorator


def register_norm_sampler(
target_class_or_classes: Union[Type[nn.Module], Sequence[Type[nn.Module]]]
):
"""
Registers the decorated function as the ``norm_sampler`` of ``target_class_or_classes``, which is
the function that will be invoked every time you want to compute a per-sample gradient norm
of ``target_class_or_classes``. The signature of every norm_sampler is always the same:
>>> @register_norm_sampler(MyCustomModel)
... def compute_grad_norm_sample(module, activations, backprops):
... pass
"""

def decorator(f):
target_classes = (
target_class_or_classes
if isinstance(target_class_or_classes, Sequence)
else [target_class_or_classes]
)
for target_class in target_classes:
GradSampleModuleFastGradientClipping.NORM_SAMPLERS[target_class] = f
return f

return decorator
Expand All @@ -70,6 +100,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
return GradSampleModule
elif grad_sample_mode == "ew":
return GradSampleModuleExpandedWeights
elif grad_sample_mode == "ghost":
return GradSampleModuleFastGradientClipping
elif grad_sample_mode == "no_op":
return GradSampleModuleNoOp
else:
Expand Down
Loading

0 comments on commit dc162b4

Please sign in to comment.