Add multi_gpu test for ghost clipping
Summary: Modify the existing `multigpu_gradcheck.py` test to check gradient correctness for ghost clipping in a distributed setting.

Differential Revision: D60840755
iden-kalemaj authored and facebook-github-bot committed Aug 13, 2024
1 parent f2a591a commit c41dcbb
Showing 1 changed file with 53 additions and 32 deletions.
85 changes: 53 additions & 32 deletions opacus/tests/multigpu_gradcheck.py
@@ -22,13 +22,18 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+import itertools
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -84,8 +89,10 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):

     dataset = TensorDataset(data, labels)

-    loss_fn = nn.MSELoss()
-    if dp and clipping == "flat":
+    reduction = "none" if dp and clipping == "ghost" else "mean"
+    loss_fn = nn.CrossEntropyLoss(reduction=reduction)
+
+    if dp and clipping in ["flat", "ghost"]:
         ddp_model = DPDDP(model)
     else:
         ddp_model = DDP(model, device_ids=[rank])
@@ -115,15 +122,24 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
                 optimizer,
                 (DistributedPerLayerOptimizer, SimpleDistributedPerLayerOptimizer),
             )
+        elif clipping == "ghost":
+            assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
         else:
             assert isinstance(optimizer, DistributedDPOptimizer)

     for x, y in data_loader:
-        outputs = ddp_model(x.to(rank))
-        loss = loss_fn(outputs, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if dp and clipping == "ghost":
+            ddp_model.enable_hooks()
+            outputs = ddp_model(x.to(rank))
+            loss_per_sample = loss_fn(outputs, y)
+            double_backward(ddp_model, optimizer, loss_per_sample)
+            optimizer.step()
+        else:
+            outputs = ddp_model(x.to(rank))
+            loss = loss_fn(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
         break

     weight.copy_(model.net1.weight.data.cpu())
@@ -141,33 +157,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DDP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )

-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )

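For context on the new code path above: `double_backward` (from `opacus.utils.fast_gradient_clipping_utils`) runs two backward passes, a first one to obtain per-sample gradient norms and a second one over the rescaled per-sample losses to accumulate the clipped gradients, which is why the ghost clipping branch builds its loss with `reduction="none"`. The sketch below is a minimal, illustrative version of the per-rank step the test exercises; it is not part of the commit. The function name `one_ghost_clipping_step` and the `.to(rank)` on the labels are made up for the example, and it assumes `ddp_model`, `optimizer`, and `data_loader` have already been prepared for ghost clipping (presumably via `PrivacyEngine.make_private` with `clipping="ghost"` and `grad_sample_mode="ghost"`, a call that sits outside the displayed hunks).

```python
# Hedged sketch (not from the commit) of the per-rank ghost clipping step.
# Assumes ddp_model/optimizer/data_loader were prepared for ghost clipping,
# e.g. by PrivacyEngine.make_private(..., clipping="ghost",
# grad_sample_mode="ghost"); that setup is not shown in the diff.
import torch.nn as nn

from opacus.utils.fast_gradient_clipping_utils import double_backward


def one_ghost_clipping_step(ddp_model, optimizer, data_loader, rank):
    # Ghost clipping needs unreduced, per-sample losses: double_backward
    # rescales each sample's loss by its clipping coefficient before the
    # second backward pass.
    loss_fn = nn.CrossEntropyLoss(reduction="none")
    for x, y in data_loader:
        ddp_model.enable_hooks()  # make sure the grad-sample hooks are active
        outputs = ddp_model(x.to(rank))
        loss_per_sample = loss_fn(outputs, y.to(rank))
        # First backward pass: per-sample gradient norms.
        # Second backward pass: gradients of the rescaled (clipped) loss.
        double_backward(ddp_model, optimizer, loss_per_sample)
        optimizer.step()  # DP optimizer step; the distributed optimizer handles cross-rank averaging
        break  # one step is enough for the weight comparison in the test
```

In the test itself, this path is run through `run_demo` with `dp=True, clipping="ghost", grad_sample_mode="ghost"` and compared against plain DDP training; the resulting `net1` weights must agree within `atol=1e-5, rtol=1e-3`.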