update lottery ticket pruner based on refactored compression code (#1989)
QuanluZhang authored Feb 5, 2020
1 parent 6b0ecee commit d452a16
Showing 4 changed files with 56 additions and 31 deletions.
2 changes: 2 additions & 0 deletions examples/model_compress/lottery_torch_mnist_fc.py
@@ -71,6 +71,8 @@ def test(model, test_loader, criterion):
     pruner = LotteryTicketPruner(model, configure_list, optimizer)
     pruner.compress()
 
+    #model = nn.DataParallel(model)
+
     for i in pruner.get_prune_iterations():
         pruner.prune_iteration_start()
         loss = 0
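For context, a minimal sketch of how this example drives the pruner end to end; train, test, the data loaders, and configure_list are as defined elsewhere in lottery_torch_mnist_fc.py, and the epoch count here is illustrative:

    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()

    for i in pruner.get_prune_iterations():
        # recompute masks for this round; weights are reset to their
        # original initialization when reset_weights is enabled
        pruner.prune_iteration_start()
        for epoch in range(10):
            train(model, train_loader, optimizer, criterion)
        accuracy = test(model, test_loader, criterion)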
2 changes: 1 addition & 1 deletion examples/model_compress/multi_gpu.py
@@ -69,7 +69,7 @@ def test(model, test_loader, criterion, device):
     train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
     test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
 
-    device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = fc1()
 
     criterion = nn.CrossEntropyLoss()
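The one-line fix above matters because "cuda: 0" (with a space) is not a valid torch device string and raises a RuntimeError at construction. A quick sketch of the accepted forms:

    import torch

    device = torch.device("cuda")    # current default CUDA device
    device = torch.device("cuda:0")  # pin GPU 0 explicitly; no space before the index
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # portable fallback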
34 changes: 30 additions & 4 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -41,6 +41,7 @@ def __init__(self, model, config_list):
         self.modules_to_compress = None
         self.modules_wrapper = None
         self.buffers = {}
+        self.is_wrapped = False
 
     def detect_modules_to_compress(self):
         """
@@ -63,6 +64,7 @@ def _wrap_model(self):
"""
for wrapper in reversed(self.get_modules_wrapper()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True

def _unwrap_model(self):
"""
@@ -71,6 +73,7 @@ def _unwrap_model(self):
"""
for wrapper in self.get_modules_wrapper():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False

def compress(self):
"""
@@ -263,7 +266,7 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
 
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide masks for weight tensors.
         The mask must have the same shape and type compared to the weight.
@@ -291,9 +294,12 @@ def _wrap_modules(self, layer, config):
             the configuration for generating the mask
         """
         _logger.info("compressing module %s.", layer.name)
-        return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight')
+        wrapper.to(layer.module.weight.device)
+        return wrapper
 
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks, and onnx model (optional).
@@ -307,6 +313,9 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting the onnx file.
+            The tensor is placed on cpu if ``device`` is None.
         """
         # if self.detect_modules_to_compress() and not self.mask_dict:
         #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
@@ -335,12 +344,29 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
 
         self._wrap_model()
 
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters
+        ----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
+
 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
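Taken together, the new device parameter on export_model and the load_model_state_dict helper might be used as below; this is a sketch, and the pruner object, file paths, and the MNIST-style input_shape are illustrative assumptions:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # export weights, masks and an onnx file; the dummy input used for
    # onnx export is placed on `device` instead of defaulting to cpu
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                        onnx_path='pruned.onnx', input_shape=[1, 1, 28, 28],
                        device=device)

    # restore a checkpoint that was saved from the unwrapped model; the
    # helper unwraps, loads, and re-wraps when the model is currently wrapped
    pruner.load_model_state_dict(torch.load('checkpoint.pth'))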
49 changes: 23 additions & 26 deletions src/sdk/pynni/nni/compression/torch/pruners.py
@@ -290,38 +290,23 @@ def _validate_config(self, config_list):
             prune_iterations = config['prune_iterations']
         return prune_iterations
 
-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')
-
     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)
 
-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}
 
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.
@@ -331,15 +316,17 @@ def calc_mask(self, layer, config):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information
 
         Returns
         -------
         tensor
-            The mask for this weight
+            The mask for this weight; it is ``None`` because this pruner
+            calculates and assigns masks in ``prune_iteration_start``,
+            so there is nothing to do in this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None
 
     def get_prune_iterations(self):
         """
@@ -364,16 +351,26 @@ def prune_iteration_start(self):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceeded the configured prune_iterations'
 
+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None
+
             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: using weight_mask directly is not ideal
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias
 
+        # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)

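As a sanity check on the schedule, _calc_sparsity keeps a constant fraction (1 - sparsity) ** (1 / prune_iterations) of the weights each round, so the configured sparsity is reached exactly on the final iteration; _calc_mask then zeroes the k smallest weight magnitudes among those still unmasked. A standalone sketch with illustrative settings (sparsity 0.8 over 5 iterations, then the same top-k thresholding on a toy tensor):

    import torch

    sparsity, prune_iterations = 0.8, 5
    keep_ratio_once = (1 - sparsity) ** (1 / prune_iterations)  # ~0.725 kept per round
    for it in range(1, prune_iterations + 1):
        print(f'iteration {it}: sparsity ~ {1 - keep_ratio_once ** it:.4f}')
    # ~0.2752, 0.4747, 0.6193, 0.7241, 0.8000 -- the last round hits the target

    # the magnitude thresholding _calc_mask applies, on a toy tensor
    weight = torch.randn(8, 8)
    curr_w_mask = torch.ones_like(weight)  # mask carried over from the previous round
    w_abs = weight.abs() * curr_w_mask
    k = int(w_abs.numel() * 0.5)           # illustrative 50% sparsity
    threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
    mask = torch.gt(w_abs, threshold).type_as(weight)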