From 0edb702c11e90b31c4eebf1709d12c041ff1284e Mon Sep 17 00:00:00 2001
From: Enayat Ullah
Date: Thu, 22 Aug 2024 12:01:34 -0700
Subject: [PATCH] Towards making the interface of ghost clipping same as that of PyTorch (#668)

Summary:
Pull Request resolved: https://github.com/pytorch/opacus/pull/668

We define two classes, DPLossFastGradientClipping and DPTensorFastGradientClipping, in the utils file. They allow us to repurpose loss.backward() to perform the two backward passes and the loss scaling required to implement ghost clipping.

Differential Revision: D61162530
---
 ...ad_sample_module_fast_gradient_clipping.py | 25 +++--
 opacus/optimizers/__init__.py                 | 15 ++-
 opacus/privacy_engine.py                      |  9 ++
 ...mple_module_fast_gradient_clipping_test.py |  2 +-
 opacus/tests/multigpu_gradcheck.py            |  2 +-
 opacus/utils/fast_gradient_clipping_utils.py  | 94 ++++++++++++++++++-
 6 files changed, 131 insertions(+), 16 deletions(-)

diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
index c86f28c6..8e23b9b3 100644
--- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
+++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -47,12 +47,21 @@ def create_norm_sample(
     """
 
     if param.requires_grad:
-        param._norm_sample = torch.zeros(
-            torch.Size([max_batch_len, 1]),
-            device=grad_sample.device,
-            dtype=grad_sample.dtype,
-        )
-        param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
+        if (
+            max_batch_len == 0
+        ):  # To handle the case of empty batch that may arise from Poisson sampling
+            param._norm_sample = torch.tensor(
+                [], device=grad_sample.device, dtype=grad_sample.dtype
+            )
+        else:
+            param._norm_sample = torch.zeros(
+                torch.Size([max_batch_len, 1]),
+                device=grad_sample.device,
+                dtype=grad_sample.dtype,
+            )
+            param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
+                2, dim=-1
+            )
 
 
 class GradSampleModuleFastGradientClipping(GradSampleModule):
@@ -110,7 +119,7 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
 
-    def get_coeff(self) -> torch.Tensor:
+    def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
         norm_sample = self.get_norm_sample()
         return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
@@ -175,6 +184,7 @@ def capture_backprops_hook(
             return
 
         backprops = forward_output[0].detach()
+
         activations, backprops = self.rearrange_grad_samples(
             module=module,
             backprops=backprops,
@@ -216,7 +226,6 @@ def capture_backprops_hook(
                     max_batch_len=module.max_batch_len,
                 )
                 del p.grad_sample
-
        if len(module.activations) == 0:
            if hasattr(module, "max_batch_len"):
                del module.max_batch_len
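The renamed get_clipping_coef() above returns the per-sample scaling factor min(1, C / (norm + eps)), where C is max_grad_norm. A minimal standalone sketch of that computation, with illustrative norm values rather than anything taken from Opacus internals:

import torch

# Per-sample gradient norms for a toy batch of 4 examples (illustrative values).
per_sample_norms = torch.tensor([0.5, 2.0, 10.0, 1.0])
max_grad_norm = 1.0  # clipping threshold C

# Same formula as get_clipping_coef(): scale each sample by min(1, C / (norm + eps)).
coeff = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
print(coeff)  # approximately tensor([1.0000, 0.5000, 0.1000, 1.0000])

Samples whose gradient norm already falls below C are left unscaled (coefficient 1), while larger gradients are scaled down to norm C.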
diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py
index 5867e127..88f79a8d 100644
--- a/opacus/optimizers/__init__.py
+++ b/opacus/optimizers/__init__.py
@@ -39,12 +39,17 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
-    if clipping == "flat" and distributed is False:
+    if grad_sample_mode == "ghost":
+        if clipping == "flat" and distributed is False:
+            return DPOptimizerFastGradientClipping
+        elif clipping == "flat" and distributed is True:
+            return DistributedDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
+            )
+    elif clipping == "flat" and distributed is False:
         return DPOptimizer
-    elif clipping == "ghost" and distributed is False:
-        return DPOptimizerFastGradientClipping
-    elif clipping == "ghost" and distributed is True:
-        return DistributedDPOptimizerFastGradientClipping
     elif clipping == "flat" and distributed is True:
         return DistributedDPOptimizer
     elif clipping == "per_layer" and distributed is False:
diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index e0789a0e..1af891c4 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -30,6 +30,7 @@
 )
 from opacus.optimizers import DPOptimizer, get_optimizer_class
 from opacus.schedulers import _GradClipScheduler, _NoiseScheduler
+from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
 from opacus.validators.module_validator import ModuleValidator
 from torch import nn, optim
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -277,6 +278,7 @@ def make_private(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         noise_multiplier: float,
         max_grad_norm: Union[float, List[float]],
@@ -400,6 +402,11 @@
         optimizer.attach_step_hook(
             self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
         )
+        if grad_sample_mode == "ghost":
+            criterion = DPLossFastGradientClipping(
+                module, optimizer, criterion, loss_reduction
+            )
+            return module, optimizer, criterion, data_loader
 
         return module, optimizer, data_loader
 
@@ -408,6 +415,7 @@ def make_private_with_epsilon(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         target_epsilon: float,
         target_delta: float,
@@ -487,6 +495,7 @@
             module=module,
             optimizer=optimizer,
             data_loader=data_loader,
+            criterion=criterion,
             noise_multiplier=get_noise_multiplier(
                 target_epsilon=target_epsilon,
                 target_delta=target_delta,
diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
index 394db742..2a5b7277 100644
--- a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
+++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -165,7 +165,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
 
         first_loss.backward(retain_graph=True)
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_coeff()
+        coeff = self.grad_sample_module.get_clipping_coef()
         second_loss_per_sample = coeff * first_loss_per_sample
         second_loss = torch.sum(second_loss_per_sample)
         self.grad_sample_module.disable_hooks()
diff --git a/opacus/tests/multigpu_gradcheck.py b/opacus/tests/multigpu_gradcheck.py
index f63b15de..6242d8e1 100644
--- a/opacus/tests/multigpu_gradcheck.py
+++ b/opacus/tests/multigpu_gradcheck.py
@@ -101,7 +101,7 @@ def run_ghost_clipping_test(
     loss_per_sample = loss_fn(outputs, y)
     torch.mean(loss_per_sample).backward(retain_graph=True)
     optimizer.zero_grad()
-    rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+    rescaled_loss_per_sample = ddp_model.get_clipping_coef() * loss_per_sample
     rescaled_loss = torch.sum(rescaled_loss_per_sample)
     ddp_model.disable_hooks()
     rescaled_loss.backward()
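With these changes, make_private() accepts a criterion and, when grad_sample_mode="ghost", returns it wrapped in DPLossFastGradientClipping as a fourth value. A minimal sketch of the updated call; the toy model, data, and hyperparameter values below are illustrative, and the remaining make_private() defaults (flat clipping, Poisson sampling, mean reduction) are assumed:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Toy model and data so the call below is concrete (illustrative only).
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8
)

privacy_engine = PrivacyEngine()
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),  # new keyword; defaults to nn.CrossEntropyLoss()
    data_loader=train_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",  # selects fast gradient / ghost clipping
)
# With any other grad_sample_mode, make_private() keeps returning the usual
# (module, optimizer, data_loader) triple.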
diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py
index 367f2495..ad81e76e 100644
--- a/opacus/utils/fast_gradient_clipping_utils.py
+++ b/opacus/utils/fast_gradient_clipping_utils.py
@@ -20,6 +20,98 @@
 from opacus.optimizers import DPOptimizerFastGradientClipping
 
 
+class DPTensorFastGradientClipping:
+    """
+    Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward().
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        loss_per_sample: torch.Tensor,
+        loss_reduction: str = "mean",
+    ):
+        """
+
+        Args:
+            module: the module to train
+            optimizer: the optimizer used to train the module
+            loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+
+        """
+
+        self.module = module
+        self.optimizer = optimizer
+        self.loss_per_sample = loss_per_sample
+        self.loss_reduction = loss_reduction
+
+    def item(self):
+        if self.loss_reduction == "mean":
+            return torch.mean(self.loss_per_sample).detach().item()
+        elif self.loss_reduction == "sum":
+            return torch.sum(self.loss_per_sample).detach().item()
+
+    def backward(self):
+        "Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between"
+        if self.loss_reduction == "mean":
+            reduced_loss = torch.mean(self.loss_per_sample)
+        elif self.loss_reduction == "sum":
+            reduced_loss = torch.sum(self.loss_per_sample)
+        else:
+            raise ValueError(
+                f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+            )
+        reduced_loss.backward(retain_graph=True)
+        self.optimizer.zero_grad()
+        coeff = self.module.get_clipping_coef()
+        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.module.disable_hooks()
+        second_loss.backward()
+        self.module.enable_hooks()
+
+
+class DPLossFastGradientClipping:
+    """
+    Wrapper around the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss and wraps it in DPTensorFastGradientClipping.
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        criterion,
+        loss_reduction: str = "mean",
+    ):
+        assert loss_reduction in [
+            "mean",
+            "sum",
+        ], "loss_reduction should be either 'mean' or 'sum'"
+        assert (
+            loss_reduction
+            == criterion.reduction
+            == module.loss_reduction
+            == optimizer.loss_reduction
+        ), "loss_reduction should be the same across GradSampleModule, Optimizer, and Criterion"
+
+        self.optimizer = optimizer
+        self.module = module
+        self.criterion = criterion
+        self.loss_reduction = loss_reduction
+        self.criterion.reduction = "none"
+
+    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+        """Redefines the forward function to compute the per-sample loss and wrap it in DPTensorFastGradientClipping"""
+        loss_per_sample = self.criterion(
+            input,
+            target,
+        )
+        return DPTensorFastGradientClipping(
+            self.module, self.optimizer, loss_per_sample, self.loss_reduction
+        )
+
+
 def double_backward(
     module: GradSampleModuleFastGradientClipping,
     optimizer: DPOptimizerFastGradientClipping,
@@ -40,7 +132,7 @@ def double_backward(
 
     torch.mean(loss_per_sample).backward(retain_graph=True)
     optimizer.zero_grad()
-    rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
+    rescaled_loss_per_sample = module.get_clipping_coef() * loss_per_sample
     rescaled_loss = torch.sum(rescaled_loss_per_sample)
     module.disable_hooks()
     rescaled_loss.backward()
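With the wrappers above, a ghost clipping training step reads like a plain PyTorch loop: criterion(...) returns a DPTensorFastGradientClipping, and its backward() runs the first backward pass to collect per-sample norms, rescales the per-sample losses by get_clipping_coef(), and runs the second backward pass with hooks disabled. A hedged sketch, continuing the make_private() example shown earlier:

model.train()
for x, y in train_loader:
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)  # DPTensorFastGradientClipping, not a plain tensor
    loss.backward()              # both backward passes and the loss rescaling happen here
    optimizer.step()             # DPOptimizerFastGradientClipping adds noise and takes the step
    print(loss.item())           # .item() reduces the stored per-sample losses

The pre-existing double_backward() helper, updated above to call get_clipping_coef(), remains available for callers who prefer to drive the two backward passes explicitly.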