This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

update lottery ticket pruner based on refactored compression code #1989

Merged
2 changes: 2 additions & 0 deletions examples/model_compress/lottery_torch_mnist_fc.py
@@ -71,6 +71,8 @@ def test(model, test_loader, criterion):
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()

#model = nn.DataParallel(model)

for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
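For context, a minimal sketch of how this loop typically continues in the example file; the train/test helpers, epoch count, and export paths shown here are assumptions, not part of this diff:

# hedged sketch of the surrounding pruning loop; train(), test(),
# training_epochs, and the export paths are assumed names
for i in pruner.get_prune_iterations():
    pruner.prune_iteration_start()
    loss, accuracy = 0, 0
    for epoch in range(training_epochs):
        loss = train(model, train_loader, optimizer, criterion)
        accuracy = test(model, test_loader, criterion)
    print('prune iteration {}: loss={:.4f}, accuracy={:.4f}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')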
2 changes: 1 addition & 1 deletion examples/model_compress/multi_gpu.py
@@ -69,7 +69,7 @@ def test(model, test_loader, criterion, device):
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)

device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = fc1()

criterion = nn.CrossEntropyLoss()
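The generic "cuda" device string matters when the model is later wrapped for multi-GPU training; a minimal sketch of that continuation (the DataParallel wrapping here is an assumption, not part of this diff):

# assumed multi-GPU continuation: a generic "cuda" device lets
# nn.DataParallel scatter inputs across all visible GPUs
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)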
34 changes: 30 additions & 4 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -41,6 +41,7 @@ def __init__(self, model, config_list):
self.modules_to_compress = None
self.modules_wrapper = None
self.buffers = {}
self.is_wrapped = False

def detect_modules_to_compress(self):
"""
@@ -63,6 +64,7 @@ def _wrap_model(self):
"""
for wrapper in reversed(self.get_modules_wrapper()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True

def _unwrap_model(self):
"""
@@ -71,6 +73,7 @@ def _unwrap_model(self):
"""
for wrapper in self.get_modules_wrapper():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False

def compress(self):
"""
@@ -263,7 +266,7 @@ class Pruner(Compressor):
def __init__(self, model, config_list):
super().__init__(model, config_list)

def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Pruners should overload this method to provide masks for weight tensors.
The mask must have the same shape and type as the weight.
@@ -291,9 +294,12 @@ def _wrap_modules(self, layer, config):
the configuration for generating the mask
"""
_logger.info("compressing module %s.", layer.name)
return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
assert hasattr(layer.module, 'weight')
Contributor: Why not use the check_weight method?

Author: Where is this method defined, and how do I use it?

Contributor: def _check_weight(module):

wrapper.to(layer.module.weight.device)
Contributor: Can you explain this?

Author: If the module is placed on a GPU, the buffers in the wrapped module are created on the CPU, so I need to move the whole wrapper to the device the module is placed on.
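A minimal sketch of the device mismatch in question; TinyWrapper here is a hypothetical stand-in for PrunerModuleWrapper:

import torch

class TinyWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # torch.ones(...) allocates on the CPU by default, so this buffer
        # starts on the CPU even if the wrapped module lives on a GPU
        self.register_buffer('weight_mask', torch.ones(module.weight.shape))

linear = torch.nn.Linear(4, 2)
if torch.cuda.is_available():
    linear = linear.cuda()
wrapper = TinyWrapper(linear)
# without this, module.weight * weight_mask would mix devices and fail
wrapper.to(linear.weight.device)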

return wrapper

def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export pruned model weights, masks, and optionally an ONNX model

@@ -307,6 +313,9 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file.
the tensor is placed on the CPU if ``device`` is None
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
@@ -335,12 +344,29 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

self._wrap_model()
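A hypothetical call using the new argument; the paths and the MNIST-style input shape are assumed:

# export a pruned model; the dummy ONNX input is created on the same
# device as the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                    onnx_path='pruned.onnx', input_shape=[1, 1, 28, 28],
                    device=device)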

def load_model_state_dict(self, model_state):
"""
Load the state dict saved from unwrapped model.

Parameters
----------
model_state : dict
state dict saved from unwrapped model
"""
if self.is_wrapped:
self._unwrap_model()
self.bound_model.load_state_dict(model_state)
self._wrap_model()
else:
Contributor: When will is_wrapped be false? Before calling the compress method?

Author: Yes. It is possible that users call load_model_state_dict before compress, so we need a way to check whether the model has been wrapped. Any suggestions?

self.bound_model.load_state_dict(model_state)
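A sketch of the intended usage on both sides of compress; the checkpoint file name is assumed:

# before compress(): the model is not wrapped, so the state dict
# loads directly
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.load_model_state_dict(torch.load('unwrapped_checkpoint.pth'))
pruner.compress()
# after compress(): the pruner transparently unwraps, loads, and
# re-wraps the model
pruner.load_model_state_dict(torch.load('unwrapped_checkpoint.pth'))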

class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
49 changes: 23 additions & 26 deletions src/sdk/pynni/nni/compression/torch/pruners.py
@@ -290,38 +290,23 @@ def _validate_config(self, config_list):
prune_iterations = config['prune_iterations']
return prune_iterations

def _print_masks(self, print_mask=False):
torch.set_printoptions(threshold=1000)
for op_name in self.mask_dict.keys():
mask = self.mask_dict[op_name]
print('op name: ', op_name)
if print_mask:
print('mask: ', mask)
# calculate current sparsity
mask_num = mask['weight'].sum().item()
mask_size = mask['weight'].numel()
print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default')

def _calc_sparsity(self, sparsity):
keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
return max(1 - curr_keep_ratio, 0)
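A worked example of the geometric schedule computed above; the final sparsity and iteration count are assumed values:

# with final sparsity 0.8 over 4 iterations, each iteration keeps
# (1 - 0.8) ** (1/4) ~= 0.669 of the remaining weights
final_sparsity, prune_iterations = 0.8, 4
keep_ratio_once = (1 - final_sparsity) ** (1 / prune_iterations)
for it in range(prune_iterations + 1):
    print(it, round(1 - keep_ratio_once ** it, 3))
# prints: 0 0.0, 1 0.331, 2 0.553, 3 0.701, 4 0.8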

def _calc_mask(self, weight, sparsity, op_name):
def _calc_mask(self, weight, sparsity, curr_w_mask):
if self.curr_prune_iteration == 0:
mask = torch.ones(weight.shape).type_as(weight)
else:
curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask['weight']
w_abs = weight.abs() * curr_w_mask
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return {'weight': mask}
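A tiny numeric check of the masking rule above, with assumed values:

import torch

w = torch.tensor([0.05, -0.2, 0.5, -0.01])
curr_w_mask = torch.ones_like(w)
w_abs = w.abs() * curr_w_mask            # [0.05, 0.2, 0.5, 0.01]
k = int(w_abs.numel() * 0.5)             # current sparsity 0.5 -> k = 2
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()  # 0.05
mask = torch.gt(w_abs, threshold).type_as(w)  # [0., 1., 1., 0.]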

def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Generate mask for the given ``weight``.

@@ -331,15 +316,17 @@ def calc_mask(self, layer, config):
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information

Returns
-------
tensor
The mask for this weight
The mask for this weight; it is ``None`` because this pruner
calculates and assigns masks in ``prune_iteration_start``,
so there is no need to do anything in this function.
"""
assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
mask = self.mask_dict[layer.name]
return mask
return None

def get_prune_iterations(self):
"""
@@ -364,16 +351,26 @@ def prune_iteration_start(self):
self.curr_prune_iteration += 1
assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

modules_wrapper = self.get_modules_wrapper()
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
module_wrapper = None
for wrapper in modules_wrapper:
if wrapper.name == layer.name:
module_wrapper = wrapper
break
assert module_wrapper is not None

sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
self.mask_dict.update({layer.name: mask})
self._print_masks()
mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
# TODO: using weight_mask directly is not good
module_wrapper.weight_mask.copy_(mask['weight'])
# there is no mask for bias

# reinit weights back to original after new masks are generated
if self.reset_weights:
self._model.load_state_dict(self._model_state)
# should use this member function to reset model weights
self.load_model_state_dict(self._model_state)
self._optimizer.load_state_dict(self._optimizer_state)
if self._lr_scheduler is not None:
self._lr_scheduler.load_state_dict(self._scheduler_state)
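For reference, the snapshots restored above would have to be taken once, before any pruning; a hedged sketch, where the helper name is hypothetical and the real pruner presumably stores these in its __init__:

import copy

def snapshot_states(model, optimizer, lr_scheduler=None):
    # deep-copy so later training cannot mutate the saved states
    model_state = copy.deepcopy(model.state_dict())
    optimizer_state = copy.deepcopy(optimizer.state_dict())
    scheduler_state = copy.deepcopy(lr_scheduler.state_dict()) if lr_scheduler else None
    return model_state, optimizer_state, scheduler_state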