diff --git a/opacus/tests/multigpu_gradcheck.py b/opacus/tests/multigpu_gradcheck.py
index af4e7bfe..f63b15de 100644
--- a/opacus/tests/multigpu_gradcheck.py
+++ b/opacus/tests/multigpu_gradcheck.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import os
 import sys
 import unittest
@@ -24,11 +25,15 @@
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +74,45 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))
 
 
+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
+        torch.mean(loss_per_sample).backward(retain_graph=True)
+        optimizer.zero_grad()
+        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+        rescaled_loss = torch.sum(rescaled_loss_per_sample)
+        ddp_model.disable_hooks()
+        rescaled_loss.backward()
+        ddp_model.enable_hooks()
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
@@ -79,12 +123,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)
 
+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)
 
-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
+    max_grad_norm = 1e8
+
     if dp and clipping == "flat":
         ddp_model = DPDDP(model)
     else:
@@ -96,8 +143,15 @@
         dataset, num_replicas=world_size, rank=rank, shuffle=False
     )
     data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # use a separate function for ghost clipping since the procedure has a different structure
+    if dp and clipping == "ghost":
+        run_ghost_clipping_test(
+            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+        )
+        return
+
     if dp:
-        max_grad_norm = 1e8
         if clipping == "per_layer":
             max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +195,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
 
 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with and without DP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )
 
-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )