Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
update level pruner to adapt to pruner dataparallel refactor (#1993)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cjkkkk authored Feb 10, 2020
1 parent d452a16 commit 4e21e72
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 18 deletions.
4 changes: 2 additions & 2 deletions examples/model_compress/MeanActivation_torch_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
from nni.compression.torch import ActivationMeanRankFilterPruner
from models.cifar10.vgg import VGG


Expand Down Expand Up @@ -96,7 +96,7 @@ def main():

# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
pruner = ActivationMeanRankFilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, model, config_list, activation='relu', statistics_batch_num=1
"""

super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
Expand All @@ -48,16 +48,23 @@ def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = []

def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))

layer.module.register_forward_hook(_hook)
wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model

def get_mask(self, base_mask, activations, num_prune):
Expand Down Expand Up @@ -103,7 +110,7 @@ def calc_mask(self, layer, config, **kwargs):
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask


Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/compression/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def compress(self):
"""
if self.modules_wrapper is not None:
# already compressed
return
return self.bound_model
else:
self.modules_wrapper = []

Expand Down
21 changes: 10 additions & 11 deletions src/sdk/pynni/nni/compression/torch/pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def __init__(self, model, config_list):
"""

super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer
Parameters
Expand All @@ -45,21 +45,20 @@ def calc_mask(self, layer, config):
"""

weight = layer.module.weight.data
op_name = layer.name
if op_name not in self.mask_calculated_ops:
if_calculated = kwargs["if_calculated"]

if not if_calculated:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask
else:
assert op_name in self.mask_dict, "op_name not in the mask_dict"
mask = self.mask_dict[op_name]
return mask
return None


class AGP_Pruner(Pruner):
Expand Down Expand Up @@ -197,7 +196,7 @@ def __init__(self, model, config_list):
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

def calc_mask(self, layer, config, **kwargs):
"""
Expand Down Expand Up @@ -232,7 +231,7 @@ def calc_mask(self, layer, config, **kwargs):
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask

class LotteryTicketPruner(Pruner):
Expand Down

0 comments on commit 4e21e72

Please sign in to comment.