diff --git a/examples/model_compress/fpgm_torch_mnist.py b/examples/model_compress/fpgm_torch_mnist.py
index e9c70be56c..db141b37d9 100644
--- a/examples/model_compress/fpgm_torch_mnist.py
+++ b/examples/model_compress/fpgm_torch_mnist.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
@@ -6,10 +7,10 @@
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ def _get_conv_weight_sparsity(self, conv_layer):
         return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)
+
         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))

diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index 1130634fdd..7f5f8e0a93 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -246,7 +246,7 @@ def forward(self, *inputs):
         self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
         # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None:
+            if mask is not None and 'bias' in mask:
                 self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)
@@ -565,4 +565,3 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
-
\ No newline at end of file
diff --git a/src/sdk/pynni/nni/compression/torch/pruners.py b/src/sdk/pynni/nni/compression/torch/pruners.py
index 4f992d2217..15c6b78262 100644
--- a/src/sdk/pynni/nni/compression/torch/pruners.py
+++ b/src/sdk/pynni/nni/compression/torch/pruners.py
@@ -83,17 +83,20 @@ def __init__(self, model, config_list):

         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Weights with the smallest absolute values are masked; the target sparsity for the current epoch is computed by compute_target_sparsity.
         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in the __init__ function
         Returns
         -------
         dict
@@ -101,24 +104,26 @@ def calc_mask(self, layer, config, **kwargs):
         """

         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weigth first
-            w_abs = weight.abs() * mask
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-        return new_mask
+
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+
+        mask = {'weight': torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate a new mask, we should update the weight first
+        w_abs = weight.abs()
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+        return new_mask

     def compute_target_sparsity(self, config):
@@ -164,9 +169,8 @@ def update_epoch(self, epoch):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
-
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable


 class SlimPruner(Pruner):
     """
diff --git a/src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py b/src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
index 00b0a3cf41..7357567def 100644
--- a/src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
+++ b/src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
@@ -27,7 +27,7 @@ def __init__(self, model, config_list):
         """

         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
@@ -69,7 +69,7 @@ def calc_mask(self, layer, config, **kwargs):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable

         return mask

@@ -257,4 +257,5 @@ def _get_distance_sum(self, weight, in_idx, out_idx):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py
index 168b949021..1992c19069 100644
--- a/src/sdk/pynni/tests/test_compressor.py
+++ b/src/sdk/pynni/tests/test_compressor.py
@@ -138,7 +138,6 @@ def test_torch_fpgm_pruner(self):
         masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@@ -159,7 +158,6 @@ def test_tf_fpgm_pruner(self):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
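
The example hunk above has to cope with the pruner replacing model.conv1 with a module wrapper that keeps the original layer under .module. Below is a minimal standalone sketch of that unwrapping; conv_filter_sparsity is a hypothetical helper whose counting logic is an assumption reconstructed from the return statement visible in the hunk, not the example's actual _get_conv_weight_sparsity body:

    import torch
    import torch.nn as nn

    def conv_filter_sparsity(layer):
        # unwrap a pruner wrapper if present; the real conv sits under .module
        conv = layer if isinstance(layer, nn.Conv2d) else layer.module
        weight = conv.weight.data
        num_filters = weight.size(0)
        # a filter counts as pruned when its whole [in, kH, kW] slice is zero
        num_zero_filters = (weight.abs().sum(dim=(1, 2, 3)) == 0).sum().item()
        return num_zero_filters, num_filters, float(num_zero_filters) / num_filters

    conv = nn.Conv2d(1, 20, 5, 1)
    conv.weight.data[:5] = 0  # pretend five filters were pruned
    print('zero filters: {}, filters: {}, sparsity: {:.4f}'.format(*conv_filter_sparsity(conv)))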
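
In compressor.py, the wrapper's forward pass now refreshes the bias mask only when calc_mask actually produced a 'bias' entry, so pruners that return weight-only masks no longer fail on mask['bias']. A small sketch of the guarded update, assuming standalone mask buffers instead of the wrapper's real state (apply_masks is a hypothetical helper):

    import torch
    import torch.nn as nn

    def apply_masks(module, mask, weight_mask, bias_mask):
        if mask is not None:
            weight_mask.copy_(mask['weight'])
        module.weight.data = module.weight.data.mul_(weight_mask)
        if hasattr(module, 'bias') and module.bias is not None:
            # the guarded line from the patch: only touch bias_mask if one was produced
            if mask is not None and 'bias' in mask:
                bias_mask.copy_(mask['bias'])
            module.bias.data = module.bias.data.mul_(bias_mask)

    conv = nn.Conv2d(1, 4, 3)
    weight_mask = torch.ones_like(conv.weight)
    bias_mask = torch.ones_like(conv.bias)
    # a weight-only mask, as the AGP pruner above returns, needs no 'bias' key
    apply_masks(conv, {'weight': torch.zeros_like(conv.weight)}, weight_mask, bias_mask)
    print(conv.weight.abs().sum().item(), conv.bias.abs().sum().item())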
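
The pruners.py and weight_rank_filter_pruners.py hunks replace the per-op bookkeeping (if_init_list, mask_calculated_ops) with an if_calculated buffer registered on every module wrapper: calc_mask receives it through **kwargs, returns None while the current mask is still valid, and update_epoch clears the flag on each wrapper so masks are recomputed at most once per epoch. The sketch below is a standalone illustration of that handshake under simplifying assumptions; TinyWrapper and the module-level calc_mask and update_epoch are stand-ins, not NNI's Pruner or PrunerModuleWrapper classes:

    import torch
    import torch.nn as nn

    class TinyWrapper(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            # 0 = mask not yet computed this epoch, 1 = already computed
            self.register_buffer('if_calculated', torch.tensor(0))
            self.registered_buffers = {'if_calculated': self.if_calculated}

    def calc_mask(wrapper, sparsity, if_calculated):
        if if_calculated:
            return None  # wrapper keeps applying the cached mask
        weight = wrapper.module.weight.data
        k = int(weight.numel() * sparsity)
        if k == 0:
            return {'weight': torch.ones_like(weight)}
        threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
        if_calculated.copy_(torch.tensor(1))
        return {'weight': torch.gt(weight.abs(), threshold).type_as(weight)}

    def update_epoch(wrappers):
        # new epoch: reset every wrapper's flag so the mask is recomputed once
        for wrapper in wrappers:
            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))

    wrapper = TinyWrapper(nn.Conv2d(1, 20, 5, 1))
    print(calc_mask(wrapper, 0.5, **wrapper.registered_buffers) is None)  # False: computed
    print(calc_mask(wrapper, 0.5, **wrapper.registered_buffers) is None)  # True: cached
    update_epoch([wrapper])
    print(calc_mask(wrapper, 0.5, **wrapper.registered_buffers) is None)  # False: recomputed

This is also why the torch FPGM test now passes a fresh if_calculated=torch.tensor(0) to each calc_mask call and drops pruner.update_epoch(1): every invocation carries its own flag, so no epoch reset is needed to force a recomputation.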