Better validator for matching parameters #489

Closed · wants to merge 4 commits
20 changes: 8 additions & 12 deletions opacus/privacy_engine.py
@@ -14,6 +14,7 @@
# limitations under the License.
import os
import warnings
from itertools import chain
from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union

import torch
@@ -383,19 +384,14 @@ def make_private(
raise ValueError("Passing seed is prohibited in secure mode")

        # compare module parameter with optimizer parameters
-        if not all(
-            torch.eq(i, j).all()
-            for i, j in zip(
-                list(module.parameters()),
-                sum(
-                    [param_group["params"] for param_group in optimizer.param_groups],
-                    [],
-                ),
-            )
+        model_parameters = set(module.parameters())
+        for p in chain.from_iterable(
+            [param_group["params"] for param_group in optimizer.param_groups]
        ):
-            raise ValueError(
-                "Module parameters are different than optimizer Parameters"
-            )
+            if p not in model_parameters:
+                raise ValueError(
+                    "Module parameters are different than optimizer Parameters"
+                )

distributed = isinstance(module, (DPDDP, DDP))

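For context, a minimal sketch (not part of the PR) of the mismatch the new identity-based check catches and the old value-based check misses, using a toy nn.Linear model:

    import torch
    from torch import nn

    m1 = nn.Linear(4, 2)
    m2 = nn.Linear(4, 2)
    m2.load_state_dict(m1.state_dict())  # equal values, different Parameter objects

    optimizer = torch.optim.SGD(m2.parameters(), lr=0.1)
    opt_params = [p for g in optimizer.param_groups for p in g["params"]]

    # Old, value-based check: passes, even though stepping this optimizer
    # would never update m1's parameters.
    print(all(torch.eq(a, b).all() for a, b in zip(m1.parameters(), opt_params)))  # True

    # New, identity-based check: fails, because set membership for Parameters
    # hashes by object identity.
    model_parameters = set(m1.parameters())
    print(all(p in model_parameters for p in opt_params))  # False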
96 changes: 84 additions & 12 deletions opacus/tests/privacy_engine_test.py
@@ -15,6 +15,7 @@

import abc
import io
import itertools
import math
import unittest
from abc import ABC
@@ -81,9 +82,18 @@ def _init_model(self):
def _init_vanilla_training(
self,
state_dict: Optional[OrderedDict[str, torch.Tensor]] = None,
opt_exclude_frozen=False,
):
model = self._init_model()
-        optimizer = torch.optim.SGD(model.parameters(), lr=self.LR, momentum=0)
+        optimizer = torch.optim.SGD(
+            [
+                p
+                for p in model.parameters()
+                if p.requires_grad or (not opt_exclude_frozen)
+            ],
+            lr=self.LR,
+            momentum=0,
+        )
if state_dict:
model.load_state_dict(state_dict)
dl = self._init_data()
@@ -98,10 +108,19 @@ def _init_private_training(
poisson_sampling: bool = True,
clipping: str = "flat",
grad_sample_mode="hooks",
opt_exclude_frozen=False,
):
model = self._init_model()
model = PrivacyEngine.get_compatible_module(model)
-        optimizer = torch.optim.SGD(model.parameters(), lr=self.LR, momentum=0)
+        optimizer = torch.optim.SGD(
+            [
+                p
+                for p in model.parameters()
+                if p.requires_grad or (not opt_exclude_frozen)
+            ],
+            lr=self.LR,
+            momentum=0,
+        )

Review comment (Contributor), on the list comprehension above: I would suggest something like this for readability

    model.parameters()
    if not opt_exclude_frozen
    else [p for p in model.parameters() if p.requires_grad]

Same in the _init_vanilla_training
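For reference, a sketch of how the reviewer's suggestion would read as a complete call (illustrative only; the diff above shows the form actually in this commit, and the sketch assumes the same opt_exclude_frozen keyword):

    optimizer = torch.optim.SGD(
        model.parameters()
        if not opt_exclude_frozen
        else [p for p in model.parameters() if p.requires_grad],
        lr=self.LR,
        momentum=0,
    )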

if state_dict:
model.load_state_dict(state_dict)
@@ -179,13 +198,16 @@ def closure():
break

def test_basic(self):
-        model, optimizer, dl, _ = self._init_private_training(
-            noise_multiplier=1.0,
-            max_grad_norm=1.0,
-            poisson_sampling=True,
-            grad_sample_mode=self.GRAD_SAMPLE_MODE,
-        )
-        self._train_steps(model, optimizer, dl)
+        for opt_exclude_frozen in [True, False]:
+            with self.subTest(opt_exclude_frozen=opt_exclude_frozen):
+                model, optimizer, dl, _ = self._init_private_training(
+                    noise_multiplier=1.0,
+                    max_grad_norm=1.0,
+                    poisson_sampling=True,
+                    grad_sample_mode=self.GRAD_SAMPLE_MODE,
+                    opt_exclude_frozen=opt_exclude_frozen,
+                )
+                self._train_steps(model, optimizer, dl)

def _compare_to_vanilla(
self,
@@ -469,8 +491,11 @@ def test_deterministic_run(self):
"Model parameters after deterministic run must match",
)

-    def test_param_equal_module_optimizer(self):
-        """Test that the privacy engine raises error if nn.Module parameters are not equal to optimizer parameters"""
+    def test_validator_weight_update_check(self):
+        """
+        Test that the privacy engine raises error if ModuleValidator.fix(model) is
+        called after the optimizer is created
+        """
model = models.densenet121(pretrained=True)
num_ftrs = model.classifier.in_features
model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())
@@ -504,7 +529,32 @@ def test_param_equal_module_optimizer(self):
max_grad_norm=1.0,
grad_sample_mode=self.GRAD_SAMPLE_MODE,
)
self.assertTrue(1, 1)
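As a standalone sketch of the failure mode this test exercises (illustrative; the actual test body is in the diff above), the sequence that breaks parameter identity looks roughly like this:

    import torch
    from torchvision import models
    from opacus import PrivacyEngine
    from opacus.validators import ModuleValidator

    model = models.densenet121(pretrained=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # created too early
    model = ModuleValidator.fix(model)  # replaces unsupported modules, creating new Parameters
    # PrivacyEngine().make_private(module=model, optimizer=optimizer, ...)
    # now raises ValueError, since the optimizer still holds the old Parameter objects.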

def test_parameters_match(self):
dl = self._init_data()

m1 = self._init_model()
m2 = self._init_model()
m2.load_state_dict(m1.state_dict())
# optimizer is initialized with m2 parameters
opt = torch.optim.SGD(m2.parameters(), lr=0.1)

        # the values are identical
for p1, p2 in zip(m1.parameters(), m2.parameters()):
self.assertTrue(torch.allclose(p1, p2))

privacy_engine = PrivacyEngine()
        # but model parameters and optimizer parameters must be the same objects,
        # not just the same values
with self.assertRaises(ValueError):
privacy_engine.make_private(
module=m1,
optimizer=opt,
data_loader=dl,
noise_multiplier=1.1,
max_grad_norm=1.0,
grad_sample_mode=self.GRAD_SAMPLE_MODE,
)
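By contrast, a sketch (not part of the test) of the passing case: build the optimizer from the same module object that is passed to make_private, so every optimizer parameter is one of the module's own Parameter objects:

    opt = torch.optim.SGD(m1.parameters(), lr=0.1)
    model, opt, dl = privacy_engine.make_private(
        module=m1,
        optimizer=opt,
        data_loader=dl,
        noise_multiplier=1.1,
        max_grad_norm=1.0,
        grad_sample_mode=self.GRAD_SAMPLE_MODE,
    )  # no ValueError: optimizer parameters are a subset of module parameters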

@given(
noise_scheduler=st.sampled_from([None, StepNoise]),
@@ -759,6 +809,28 @@ def _init_model(
return SampleConvNet()


class PrivacyEngineConvNetFrozenTest(BasePrivacyEngineTest, unittest.TestCase):
def _init_data(self):
ds = FakeData(
size=self.DATA_SIZE,
image_size=(1, 35, 35),
num_classes=10,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
m = SampleConvNet()
for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
p.requires_grad = False

return m
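This test class reruns the whole BasePrivacyEngineTest suite with conv1 and gnorm1 frozen, so opt_exclude_frozen=True builds the optimizer over a strict subset of module.parameters(); a small sketch (illustrative) of why the new membership check still passes in that case:

    m = SampleConvNet()
    for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
        p.requires_grad = False

    trainable = [p for p in m.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.05)

    # Every optimizer parameter is still one of the module's own Parameter objects,
    # so the subset-style check in make_private raises nothing.
    assert all(p in set(m.parameters()) for p in trainable)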


@unittest.skipIf(
torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version"
)