Fix pruners for DataParallel support (#2003)
chicm-ms authored Feb 10, 2020
1 parent 4e21e72 commit c7d5803
Showing 5 changed files with 44 additions and 35 deletions.
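
The common thread across all five files: per-layer pruning state such as the if_calculated flag moves out of plain Python dicts and into buffers registered on the module wrappers, because torch.nn.DataParallel replicates parameters and buffers to each GPU replica but not arbitrary Python attributes. A minimal usage sketch under that reading, assuming the NNI 1.x compression API from the example below (the DataParallel wrapping and the config values are illustrative assumptions, not part of this commit):

import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner

# Illustrative only: prune first, then wrap with DataParallel so the
# wrappers' registered buffers travel with each GPU replica.
model = Mnist()  # the example model defined in fpgm_torch_mnist.py below
pruner = FPGMPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
model = pruner.compress()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)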
19 changes: 13 additions & 6 deletions examples/model_compress/fpgm_torch_mnist.py
@@ -1,15 +1,16 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
 
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ def _get_conv_weight_sparsity(self, conv_layer):
         return num_zero_filters, num_filters, float(num_zero_filters)/num_filters
 
     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)
+
         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
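
Since compress() swaps each pruned layer for a PrunerModuleWrapper, attribute access such as self.conv1 can return either the raw Conv2d or the wrapper, which is what the isinstance branch above handles. A hedged one-liner capturing the same unwrap pattern (the helper name is an assumption):

def unwrap_conv(layer):
    # Before pruning the layer is a plain nn.Conv2d; after pruning the
    # original module lives on the wrapper's .module attribute.
    return layer if isinstance(layer, nn.Conv2d) else layer.module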

3 changes: 1 addition & 2 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -246,7 +246,7 @@ def forward(self, *inputs):
         self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
         # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None:
+            if mask is not None and 'bias' in mask:
                 self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)
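
The guard added here matters because structured pruners such as FPGM return a mask dict containing only a 'weight' entry; unconditionally indexing mask['bias'] would raise a KeyError on layers that have a bias. A standalone sketch of the guarded application, with names mirroring the diff (the free function itself is an illustration, not the library API):

def apply_masks(module, mask, weight_mask, bias_mask):
    # The weight mask is always applied; it defaults to all-ones.
    if mask is not None:
        weight_mask.copy_(mask['weight'])
    module.weight.data = module.weight.data.mul_(weight_mask)
    # The bias mask is applied only when the module has a bias AND the
    # pruner actually produced a 'bias' entry -- the fix in this hunk.
    if getattr(module, 'bias', None) is not None:
        if mask is not None and 'bias' in mask:
            bias_mask.copy_(mask['bias'])
        module.bias.data = module.bias.data.mul_(bias_mask)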
@@ -565,4 +565,3 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
-
48 changes: 26 additions & 22 deletions src/sdk/pynni/nni/compression/torch/pruners.py
@@ -83,42 +83,47 @@ def __init__(self, model, config_list):
 
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
 
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.
         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function
         Returns
         -------
         dict
             dictionary for storing masks
         """
-
         weight = layer.module.weight.data
         op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weigth first
-            w_abs = weight.abs() * mask
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+
+        mask = {'weight': torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate new mask, we should update weigth first
+        w_abs = weight.abs()
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
 
         return new_mask
 
     def compute_target_sparsity(self, config):
@@ -164,9 +169,8 @@ def update_epoch(self, epoch):
 
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
-
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
 
 class SlimPruner(Pruner):
     """
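
Replacing AGP's if_init_list dict with a registered if_calculated buffer is the heart of the DataParallel fix: buffers are part of module state, so nn.DataParallel copies them to every replica, and update_epoch can reset them in bulk through the wrappers. A toy illustration of why the buffer survives replication (ToyWrapper is a hypothetical stand-in for PrunerModuleWrapper):

import torch
import torch.nn as nn

class ToyWrapper(nn.Module):
    # Hypothetical stand-in for PrunerModuleWrapper, for illustration only.
    def __init__(self, module):
        super().__init__()
        self.module = module
        # A registered buffer is replicated by nn.DataParallel alongside
        # the parameters; a plain Python attribute would not be.
        self.register_buffer('if_calculated', torch.tensor(0))

    def forward(self, x):
        return self.module(x)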
@@ -27,7 +27,7 @@ def __init__(self, model, config_list):
         """
 
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
 
     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
@@ -69,7 +69,7 @@ def calc_mask(self, layer, config, **kwargs):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
 
 
@@ -257,4 +257,5 @@ def _get_distance_sum(self, weight, in_idx, out_idx):
         return x.sum()
 
     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
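
With update_epoch now clearing every wrapper's if_calculated buffer, masks are recomputed once per epoch. A hedged sketch of the intended training loop (pruner, model, optimizer, and train_loader are assumed to be set up as in the MNIST example above):

import torch.nn.functional as F

for epoch in range(10):
    pruner.update_epoch(epoch)  # reset if_calculated so masks are recomputed
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()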
2 changes: 0 additions & 2 deletions src/sdk/pynni/tests/test_compressor.py
@@ -138,7 +138,6 @@ def test_torch_fpgm_pruner(self):
         masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
 
-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@@ -159,7 +158,6 @@ def test_tf_fpgm_pruner(self):
 
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
 
-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])