diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py index c86f28c6..8e23b9b3 100644 --- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py +++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py @@ -47,12 +47,21 @@ def create_norm_sample( """ if param.requires_grad: - param._norm_sample = torch.zeros( - torch.Size([max_batch_len, 1]), - device=grad_sample.device, - dtype=grad_sample.dtype, - ) - param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1) + if ( + max_batch_len == 0 + ): # To handle the case of empty batch that may arise from Poisson sampling + param._norm_sample = torch.tensor( + [], device=grad_sample.device, dtype=grad_sample.dtype + ) + else: + param._norm_sample = torch.zeros( + torch.Size([max_batch_len, 1]), + device=grad_sample.device, + dtype=grad_sample.dtype, + ) + param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm( + 2, dim=-1 + ) class GradSampleModuleFastGradientClipping(GradSampleModule): @@ -110,7 +119,7 @@ def __init__( self.max_grad_norm = max_grad_norm self.use_ghost_clipping = use_ghost_clipping - def get_coeff(self) -> torch.Tensor: + def get_clipping_coef(self) -> torch.Tensor: """Get per-example gradient scaling factor for clipping.""" norm_sample = self.get_norm_sample() return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0) @@ -175,6 +184,7 @@ def capture_backprops_hook( return backprops = forward_output[0].detach() + activations, backprops = self.rearrange_grad_samples( module=module, backprops=backprops, @@ -216,7 +226,6 @@ def capture_backprops_hook( max_batch_len=module.max_batch_len, ) del p.grad_sample - if len(module.activations) == 0: if hasattr(module, "max_batch_len"): del module.max_batch_len diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py index 5867e127..d5c3b104 100644 --- a/opacus/optimizers/__init__.py +++ 
b/opacus/optimizers/__init__.py @@ -39,12 +39,12 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None): - if clipping == "flat" and distributed is False: - return DPOptimizer - elif clipping == "ghost" and distributed is False: + if grad_sample_mode == "ghost" and clipping == "flat" and distributed is False: return DPOptimizerFastGradientClipping - elif clipping == "ghost" and distributed is True: + elif grad_sample_mode == "ghost" and clipping == "flat" and distributed is True: return DistributedDPOptimizerFastGradientClipping + elif clipping == "flat" and distributed is False: + return DPOptimizer elif clipping == "flat" and distributed is True: return DistributedDPOptimizer elif clipping == "per_layer" and distributed is False: diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index e0789a0e..1af891c4 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -30,6 +30,7 @@ ) from opacus.optimizers import DPOptimizer, get_optimizer_class from opacus.schedulers import _GradClipScheduler, _NoiseScheduler +from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping from opacus.validators.module_validator import ModuleValidator from torch import nn, optim from torch.nn.parallel import DistributedDataParallel as DDP @@ -277,6 +278,7 @@ def make_private( *, module: nn.Module, optimizer: optim.Optimizer, + criterion=nn.CrossEntropyLoss(), # Added default for backward compatibility data_loader: DataLoader, noise_multiplier: float, max_grad_norm: Union[float, List[float]], @@ -400,6 +402,11 @@ def make_private( optimizer.attach_step_hook( self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate) ) + if grad_sample_mode == "ghost": + criterion = DPLossFastGradientClipping( + module, optimizer, criterion, loss_reduction + ) + return module, optimizer, criterion, data_loader return module, optimizer, data_loader @@ -408,6 +415,7 @@ def make_private_with_epsilon( *, module: nn.Module,
optimizer: optim.Optimizer, + criterion=nn.CrossEntropyLoss(), # Added default for backward compatibility data_loader: DataLoader, target_epsilon: float, target_delta: float, @@ -487,6 +495,7 @@ def make_private_with_epsilon( module=module, optimizer=optimizer, data_loader=data_loader, + criterion=criterion, noise_multiplier=get_noise_multiplier( target_epsilon=target_epsilon, target_delta=target_delta, diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py index 394db742..2a5b7277 100644 --- a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py +++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py @@ -165,7 +165,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim): first_loss.backward(retain_graph=True) optimizer_gc.zero_grad() - coeff = self.grad_sample_module.get_coeff() + coeff = self.grad_sample_module.get_clipping_coef() second_loss_per_sample = coeff * first_loss_per_sample second_loss = torch.sum(second_loss_per_sample) self.grad_sample_module.disable_hooks() diff --git a/opacus/tests/multigpu_gradcheck.py b/opacus/tests/multigpu_gradcheck.py index f63b15de..6242d8e1 100644 --- a/opacus/tests/multigpu_gradcheck.py +++ b/opacus/tests/multigpu_gradcheck.py @@ -101,7 +101,7 @@ def run_ghost_clipping_test( loss_per_sample = loss_fn(outputs, y) torch.mean(loss_per_sample).backward(retain_graph=True) optimizer.zero_grad() - rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample + rescaled_loss_per_sample = ddp_model.get_clipping_coef() * loss_per_sample rescaled_loss = torch.sum(rescaled_loss_per_sample) ddp_model.disable_hooks() rescaled_loss.backward() diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py index 367f2495..ad81e76e 100644 --- a/opacus/utils/fast_gradient_clipping_utils.py +++ b/opacus/utils/fast_gradient_clipping_utils.py @@ -20,6 +20,98 @@ from
opacus.optimizers import DPOptimizerFastGradientClipping +class DPTensorFastGradientClipping: + """ + Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward(). + """ + + def __init__( + self, + module: GradSampleModuleFastGradientClipping, + optimizer: DPOptimizerFastGradientClipping, + loss_per_sample: torch.Tensor, + loss_reduction: str = "mean", + ): + """ + + Args: + module: the module to train + optimizer: the optimizer used to train the module + loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1] + + """ + + self.module = module + self.optimizer = optimizer + self.loss_per_sample = loss_per_sample + self.loss_reduction = loss_reduction + + def item(self): + if self.loss_reduction == "mean": + return torch.mean(self.loss_per_sample).detach().item() + elif self.loss_reduction == "sum": + return torch.sum(self.loss_per_sample).detach().item() + + def backward(self): + "Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between" + if self.loss_reduction == "mean": + reduced_loss = torch.mean(self.loss_per_sample) + elif self.loss_reduction == "sum": + reduced_loss = torch.sum(self.loss_per_sample) + else: + raise ValueError( + f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported" + ) + reduced_loss.backward(retain_graph=True) + self.optimizer.zero_grad() + coeff = self.module.get_clipping_coef() + second_loss_per_sample = coeff * self.loss_per_sample + second_loss = torch.sum(second_loss_per_sample) + self.module.disable_hooks() + second_loss.backward() + self.module.enable_hooks() + + +class DPLossFastGradientClipping: + """ + Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss, and wraps it in DPTensorFastGradientClipping. 
+ """ + + def __init__( + self, + module: GradSampleModuleFastGradientClipping, + optimizer: DPOptimizerFastGradientClipping, + criterion, + loss_reduction: str = "mean", + ): + assert loss_reduction in [ + "mean", + "sum", + ], "loss_reduction should be either 'mean' or 'sum'" + assert ( + loss_reduction + == criterion.reduction + == module.loss_reduction + == optimizer.loss_reduction + ), "loss_reduction should be the same across GradSampleModule, Optimizer, Criterion, and loss_reduction" + + self.optimizer = optimizer + self.module = module + self.criterion = criterion + self.loss_reduction = loss_reduction + self.criterion.reduction = "none" + + def __call__(self, input, target) -> DPTensorFastGradientClipping: + """Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping""" + loss_per_sample = self.criterion( + input, + target, + ) + return DPTensorFastGradientClipping( + self.module, self.optimizer, loss_per_sample, self.loss_reduction + ) + + def double_backward( module: GradSampleModuleFastGradientClipping, optimizer: DPOptimizerFastGradientClipping, @@ -40,7 +132,7 @@ def double_backward( torch.mean(loss_per_sample).backward(retain_graph=True) optimizer.zero_grad() - rescaled_loss_per_sample = module.get_coeff() * loss_per_sample + rescaled_loss_per_sample = module.get_clipping_coef() * loss_per_sample rescaled_loss = torch.sum(rescaled_loss_per_sample) module.disable_hooks() rescaled_loss.backward()