Add multi_gpu test for ghost clipping (#665)
Summary:
Pull Request resolved: #665

Modify the existing `multigpu_gradcheck.py` test to check gradient correctness for ghost clipping in a distributed setting.

Differential Revision: D60840755
iden-kalemaj authored and facebook-github-bot committed Aug 15, 2024
1 parent f2a591a commit a83f1d2
Showing 1 changed file with 88 additions and 28 deletions.
116 changes: 88 additions & 28 deletions opacus/tests/multigpu_gradcheck.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import sys
import unittest
@@ -24,11 +25,16 @@
import torch.optim as optim
from opacus import PrivacyEngine
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers.ddp_perlayeroptimizer import (
    DistributedPerLayerOptimizer,
    SimpleDistributedPerLayerOptimizer,
)
from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
    DistributedDPOptimizerFastGradientClipping,
)
from opacus.utils.fast_gradient_clipping_utils import double_backward
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +75,45 @@ def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_ghost_clipping_test(
    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
):

    ddp_model = DPDDP(model)
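    # Wrap the DPDDP model for fast gradient clipping; with use_ghost_clipping=True,
    # per-sample gradient norms are computed without materializing per-sample gradients.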
    ddp_model = GradSampleModuleFastGradientClipping(
        ddp_model,
        max_grad_norm=max_grad_norm,
        use_ghost_clipping=True,
    )
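    # noise_multiplier=0 adds no noise, so the resulting weights can be compared
    # directly against the non-private run.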
    optimizer = DistributedDPOptimizerFastGradientClipping(
        optimizer,
        noise_multiplier=0,
        max_grad_norm=max_grad_norm,
        expected_batch_size=batch_size,
    )

    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)

    loss_fn = nn.CrossEntropyLoss(reduction="none")

    for x, y in data_loader:
        ddp_model.enable_hooks()
        outputs = ddp_model(x.to(rank))
        loss_per_sample = loss_fn(outputs, y)
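        # First backward pass (hooks enabled): computes the per-sample gradient norms.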
        torch.mean(loss_per_sample).backward(retain_graph=True)
        optimizer.zero_grad()
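        # Second backward pass: rescale the per-sample losses by the clipping
        # coefficients and backpropagate with hooks disabled to obtain the clipped gradients.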
        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
        rescaled_loss = torch.sum(rescaled_loss_per_sample)
        ddp_model.disable_hooks()
        rescaled_loss.backward()
        ddp_model.enable_hooks()
        optimizer.step()
        break

    weight.copy_(model.net1.weight.data.cpu())
    cleanup()


def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
    torch.manual_seed(world_size)
    batch_size = 32
@@ -79,12 +124,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
    model.net1.weight.data.zero_()
    optimizer = optim.SGD(model.parameters(), lr=1)

    # create dataset
    labels = torch.randn(2 * batch_size, 5).to(rank)
    data = torch.randn(2 * batch_size, 10)

    dataset = TensorDataset(data, labels)

    loss_fn = nn.MSELoss()
    loss_fn = nn.CrossEntropyLoss()

    max_grad_norm = 1e8
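    # 1e8 effectively disables clipping, so the DP and non-DP runs should
    # produce the same gradients.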

    if dp and clipping == "flat":
        ddp_model = DPDDP(model)
    else:
@@ -96,8 +144,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # use a separate function for ghost clipping since the procedure has a different structure
    if dp and clipping == "ghost":
        run_ghost_clipping_test(
            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
        )
        return

    if dp:
        max_grad_norm = 1e8
        if clipping == "per_layer":
            max_grad_norm = [max_grad_norm for _ in model.parameters()]
        ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +196,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

class GradientComputationTest(unittest.TestCase):
    def test_gradient_correct(self) -> None:
        # Tests that gradient is the same with DP or with DDP
        # Tests that the gradient is the same with and without DP in the distributed setting
        n_gpus = torch.cuda.device_count()
        self.assertTrue(
            n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
        )

        for clipping in ["flat", "per_layer"]:
            for grad_sample_mode in ["hooks", "ew"]:
                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)

                run_demo(
                    demo_basic,
                    weight_dp,
                    2,
                    dp=True,
                    clipping=clipping,
                    grad_sample_mode=grad_sample_mode,
                )
                run_demo(
                    demo_basic,
                    weight_nodp,
                    2,
                    dp=False,
                    clipping=None,
                    grad_sample_mode=None,
                )

                self.assertTrue(
                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
                )
        clipping_grad_sample_pairs = list(
            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
        )
        clipping_grad_sample_pairs.append(("ghost", "ghost"))
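        # ghost clipping is dispatched to run_ghost_clipping_test inside demo_basic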

        for clipping, grad_sample_mode in clipping_grad_sample_pairs:

            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)

            run_demo(
                demo_basic,
                weight_dp,
                2,
                dp=True,
                clipping=clipping,
                grad_sample_mode=grad_sample_mode,
            )
            run_demo(
                demo_basic,
                weight_nodp,
                2,
                dp=False,
                clipping=None,
                grad_sample_mode=None,
            )

            self.assertTrue(
                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
            )
