diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index 43f206e7..dd4be4db 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import os
 import warnings
+from itertools import chain
 from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union

 import torch
@@ -383,19 +384,14 @@ def make_private(
             raise ValueError("Passing seed is prohibited in secure mode")

         # compare module parameter with optimizer parameters
-        if not all(
-            torch.eq(i, j).all()
-            for i, j in zip(
-                list(module.parameters()),
-                sum(
-                    [param_group["params"] for param_group in optimizer.param_groups],
-                    [],
-                ),
-            )
+        model_parameters = set(module.parameters())
+        for p in chain.from_iterable(
+            [param_group["params"] for param_group in optimizer.param_groups]
         ):
-            raise ValueError(
-                "Module parameters are different than optimizer Parameters"
-            )
+            if p not in model_parameters:
+                raise ValueError(
+                    "Module parameters are different than optimizer Parameters"
+                )

         distributed = isinstance(module, (DPDDP, DDP))

diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py
index 488e8a90..90af717a 100644
--- a/opacus/tests/privacy_engine_test.py
+++ b/opacus/tests/privacy_engine_test.py
@@ -15,6 +15,7 @@

 import abc
 import io
+import itertools
 import math
 import unittest
 from abc import ABC
@@ -81,9 +82,16 @@ def _init_model(self):
     def _init_vanilla_training(
         self,
         state_dict: Optional[OrderedDict[str, torch.Tensor]] = None,
+        opt_exclude_frozen=False,
     ):
         model = self._init_model()
-        optimizer = torch.optim.SGD(model.parameters(), lr=self.LR, momentum=0)
+        optimizer = torch.optim.SGD(
+            model.parameters()
+            if not opt_exclude_frozen
+            else [p for p in model.parameters() if p.requires_grad],
+            lr=self.LR,
+            momentum=0,
+        )
         if state_dict:
             model.load_state_dict(state_dict)
         dl = self._init_data()
@@ -98,10 +106,17 @@ def _init_private_training(
         poisson_sampling: bool = True,
         clipping: str = "flat",
         grad_sample_mode="hooks",
+        opt_exclude_frozen=False,
     ):
         model = self._init_model()
         model = PrivacyEngine.get_compatible_module(model)
-        optimizer = torch.optim.SGD(model.parameters(), lr=self.LR, momentum=0)
+        optimizer = torch.optim.SGD(
+            model.parameters()
+            if not opt_exclude_frozen
+            else [p for p in model.parameters() if p.requires_grad],
+            lr=self.LR,
+            momentum=0,
+        )
         if state_dict:
             model.load_state_dict(state_dict)

@@ -179,13 +194,16 @@ def closure():
                 break

     def test_basic(self):
-        model, optimizer, dl, _ = self._init_private_training(
-            noise_multiplier=1.0,
-            max_grad_norm=1.0,
-            poisson_sampling=True,
-            grad_sample_mode=self.GRAD_SAMPLE_MODE,
-        )
-        self._train_steps(model, optimizer, dl)
+        for opt_exclude_frozen in [True, False]:
+            with self.subTest(opt_exclude_frozen=opt_exclude_frozen):
+                model, optimizer, dl, _ = self._init_private_training(
+                    noise_multiplier=1.0,
+                    max_grad_norm=1.0,
+                    poisson_sampling=True,
+                    grad_sample_mode=self.GRAD_SAMPLE_MODE,
+                    opt_exclude_frozen=opt_exclude_frozen,
+                )
+                self._train_steps(model, optimizer, dl)

     def _compare_to_vanilla(
         self,
@@ -469,8 +487,11 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )

-    def test_param_equal_module_optimizer(self):
-        """Test that the privacy engine raises error if nn.Module parameters are not equal to optimizer parameters"""
+    def test_validator_weight_update_check(self):
+        """
+        Test that the privacy engine raises an error if ModuleValidator.fix(model)
+        is called after the optimizer is created
+        """
         model = models.densenet121(pretrained=True)
         num_ftrs = model.classifier.in_features
         model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())
@@ -504,7 +525,32 @@ def test_param_equal_module_optimizer(self):
             max_grad_norm=1.0,
             grad_sample_mode=self.GRAD_SAMPLE_MODE,
         )
-        self.assertTrue(1, 1)
+
+    def test_parameters_match(self):
+        dl = self._init_data()
+
+        m1 = self._init_model()
+        m2 = self._init_model()
+        m2.load_state_dict(m1.state_dict())
+        # optimizer is initialized with m2's parameters
+        opt = torch.optim.SGD(m2.parameters(), lr=0.1)
+
+        # the values are identical
+        for p1, p2 in zip(m1.parameters(), m2.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
+        privacy_engine = PrivacyEngine()
+        # but model parameters and optimizer parameters must be the same object,
+        # not just the same values
+        with self.assertRaises(ValueError):
+            privacy_engine.make_private(
+                module=m1,
+                optimizer=opt,
+                data_loader=dl,
+                noise_multiplier=1.1,
+                max_grad_norm=1.0,
+                grad_sample_mode=self.GRAD_SAMPLE_MODE,
+            )

     @given(
         noise_scheduler=st.sampled_from([None, StepNoise]),
@@ -759,6 +805,28 @@ def _init_model(

         return SampleConvNet()

+
+class PrivacyEngineConvNetFrozenTest(BasePrivacyEngineTest, unittest.TestCase):
+    def _init_data(self):
+        ds = FakeData(
+            size=self.DATA_SIZE,
+            image_size=(1, 35, 35),
+            num_classes=10,
+            transform=transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+            ),
+        )
+        return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)
+
+    def _init_model(
+        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
+    ):
+        m = SampleConvNet()
+        for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
+            p.requires_grad = False
+
+        return m
+
+
 @unittest.skipIf(
     torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version"
 )
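
For context on what the new identity-based check permits: an optimizer constructed over only the trainable subset of a module's parameters now passes make_private, because each optimizer parameter is tested for membership (by object identity) in set(module.parameters()) instead of being compared positionally by value. Below is a minimal sketch of that usage; the toy model, layer sizes, and data are invented for illustration, while the PrivacyEngine.make_private call mirrors the one exercised in the tests above.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from opacus import PrivacyEngine

    # Toy model with a frozen first layer; only the head remains trainable.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
    for p in model[0].parameters():
        p.requires_grad = False

    # The old zip/torch.eq check compared parameters positionally, so a
    # trainable-subset optimizer like this one could fail it; the new check
    # accepts it because every optimizer param is also a module param.
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=0.1
    )

    data_loader = DataLoader(
        TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
        batch_size=8,
    )

    model, optimizer, data_loader = PrivacyEngine().make_private(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        noise_multiplier=1.0,
        max_grad_norm=1.0,
    )

The converse still raises: an optimizer holding tensors that are numerically equal to, but not the same objects as, the module's parameters (e.g. built from a separate copy of the model) fails the membership test with ValueError, which is exactly what test_parameters_match pins down.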