From 4e21e721a65d0ac7c8465c6b7842dd39338bb3d0 Mon Sep 17 00:00:00 2001
From: Cjkkkk <656569648@qq.com>
Date: Mon, 10 Feb 2020 09:40:14 +0800
Subject: [PATCH] update level pruner to adapt to pruner dataparallel refactor
 (#1993)

---
 .../MeanActivation_torch_cifar10.py           |  4 ++--
 .../torch/activation_rank_filter_pruners.py   | 15 +++++++++----
 .../pynni/nni/compression/torch/compressor.py |  2 +-
 .../pynni/nni/compression/torch/pruners.py    | 21 +++++++++----------
 4 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/examples/model_compress/MeanActivation_torch_cifar10.py b/examples/model_compress/MeanActivation_torch_cifar10.py
index 9d3c73bfe7..1d5e38b6ff 100644
--- a/examples/model_compress/MeanActivation_torch_cifar10.py
+++ b/examples/model_compress/MeanActivation_torch_cifar10.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from nni.compression.torch import L1FilterPruner
+from nni.compression.torch import ActivationMeanRankFilterPruner
 from models.cifar10.vgg import VGG
 
 
@@ -96,7 +96,7 @@ def main():
 
     # Prune model and test accuracy without fine tuning.
     print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
-    pruner = L1FilterPruner(model, configure_list)
+    pruner = ActivationMeanRankFilterPruner(model, configure_list)
     model = pruner.compress()
     if args.parallel:
         if torch.cuda.device_count() > 1:
diff --git a/src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py b/src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
index 0bbfa72da5..fd3650a031 100644
--- a/src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
+++ b/src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
@@ -32,7 +32,7 @@ def __init__(self, model, config_list, activation='relu', statistics_batch_num=1
         """
 
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
@@ -48,16 +48,23 @@ def compress(self):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
+
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []
 
             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))
 
-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model
 
     def get_mask(self, base_mask, activations, num_prune):
@@ -103,7 +110,7 @@ def calc_mask(self, layer, config, **kwargs):
                 mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+                if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
 
 
diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index 55cc05aaac..1130634fdd 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -89,7 +89,7 @@ def compress(self):
         """
         if self.modules_wrapper is not None:
             # already compressed
-            return
+            return self.bound_model
         else:
             self.modules_wrapper = []
 
diff --git a/src/sdk/pynni/nni/compression/torch/pruners.py b/src/sdk/pynni/nni/compression/torch/pruners.py
index b0a27c33b3..4f992d2217 100644
--- a/src/sdk/pynni/nni/compression/torch/pruners.py
+++ b/src/sdk/pynni/nni/compression/torch/pruners.py
@@ -27,9 +27,9 @@ def __init__(self, model, config_list):
         """
 
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
 
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer
         Parameters
@@ -45,8 +45,9 @@
         """
 
         weight = layer.module.weight.data
-        op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
@@ -54,12 +55,10 @@
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            return mask
         else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+            return None
 
 
 class AGP_Pruner(Pruner):
@@ -197,7 +196,7 @@ def __init__(self, model, config_list):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
 
     def calc_mask(self, layer, config, **kwargs):
         """
@@ -232,7 +231,7 @@ def calc_mask(self, layer, config, **kwargs):
         mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
         mask_bias = mask_weight.clone()
         mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
 
 class LotteryTicketPruner(Pruner):
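
Note (appended after the patch, not part of the diff): every hunk above applies the same pattern. Per-layer bookkeeping moves out of pruner-side Python state (the old mask_calculated_ops set and mask_dict lookup) into an if_calculated buffer registered on the module wrapper, and the flag is stored as an integer tensor (torch.tensor(0) / torch.tensor(1)) so it can be updated in place with copy_(). Registered buffers are part of module state, so they are moved and replicated together with the module under torch.nn.DataParallel, while a plain attribute on the pruner object would not be. The sketch below is a minimal, hypothetical illustration of that pattern; ToyWrapper and its method names are illustrative only, not the NNI API.

    import torch
    import torch.nn as nn

    class ToyWrapper(nn.Module):
        """Sketch: per-module pruning state kept in registered buffers."""

        def __init__(self, module, sparsity):
            super().__init__()
            self.module = module
            self.sparsity = sparsity
            # Buffers travel with the module (device moves, DataParallel
            # replication), unlike plain Python attributes of a pruner object.
            self.register_buffer("if_calculated", torch.tensor(0))
            self.register_buffer("weight_mask", torch.ones_like(module.weight))

        def calc_mask(self):
            # Mirrors the patched LevelPruner.calc_mask: compute the mask once,
            # then signal "already done" by returning None on later calls.
            if self.if_calculated:
                return None
            w_abs = self.module.weight.data.abs()
            k = int(w_abs.numel() * self.sparsity)
            if k > 0:
                # Threshold is the largest magnitude among the k smallest
                # weights; everything strictly above it survives.
                threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
                self.weight_mask.copy_(torch.gt(w_abs, threshold).type_as(w_abs))
            self.if_calculated.copy_(torch.tensor(1))
            return self.weight_mask

    # Usage: the mask is computed on the first call and skipped afterwards.
    layer = ToyWrapper(nn.Linear(16, 8), sparsity=0.5)
    assert layer.calc_mask() is not None
    assert layer.calc_mask() is None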