update lottery ticket pruner based on refactored compression code #1989
Changes from all commits:
```diff
@@ -41,6 +41,7 @@ def __init__(self, model, config_list):
         self.modules_to_compress = None
         self.modules_wrapper = None
         self.buffers = {}
+        self.is_wrapped = False
 
     def detect_modules_to_compress(self):
         """
@@ -63,6 +64,7 @@ def _wrap_model(self):
         """
         for wrapper in reversed(self.get_modules_wrapper()):
             _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True
 
     def _unwrap_model(self):
         """
@@ -71,6 +73,7 @@ def _unwrap_model(self):
         """
         for wrapper in self.get_modules_wrapper():
             _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False
 
     def compress(self):
         """
```
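The new `is_wrapped` flag records whether the wrapper modules are currently installed on `bound_model`, so other code can unwrap and re-wrap safely (as `load_model_state_dict` further down does). A minimal sketch of how the flag might be consulted; the `ensure_unwrapped` helper is hypothetical and not part of this PR:

```python
def ensure_unwrapped(pruner):
    """Return the raw bound model, removing wrappers first if installed.

    Hypothetical helper: relies only on the `is_wrapped` flag and the
    `_unwrap_model` method shown in the diff above.
    """
    if pruner.is_wrapped:       # set by _wrap_model, cleared by _unwrap_model
        pruner._unwrap_model()  # restores the original sub-modules on bound_model
    return pruner.bound_model
```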
```diff
@@ -263,7 +266,7 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
 
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
```
```diff
@@ -291,9 +294,12 @@ def _wrap_modules(self, layer, config):
             the configuration for generating the mask
         """
         _logger.info("compressing module %s.", layer.name)
-        return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight')
+        wrapper.to(layer.module.weight.device)
+        return wrapper
 
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks and onnx model(optional)
 
```

A review thread on the `wrapper.to(...)` line:

> **Reviewer:** can you explain this?
>
> **Author:** if the module is placed on a GPU, buffers in the wrapped module are put on the CPU, so I need to move the whole wrapper to the device the module is placed on.
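The author's reply can be reproduced in a few lines: a mask registered via `register_buffer` is allocated on the CPU even when the wrapped module already lives on a GPU, so the wrapper must be moved as a whole. A minimal repro sketch; `Wrapper` is a simplified stand-in for `PrunerModuleWrapper`, and it assumes a CUDA device is available:

```python
import torch

class Wrapper(torch.nn.Module):
    """Simplified stand-in for PrunerModuleWrapper."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        # torch.ones(...) allocates on the CPU by default,
        # regardless of where `module` lives
        self.register_buffer('weight_mask', torch.ones(module.weight.shape))

conv = torch.nn.Conv2d(3, 8, 3).to('cuda')
wrapper = Wrapper(conv)
print(wrapper.module.weight.device)  # cuda:0
print(wrapper.weight_mask.device)    # cpu, a device mismatch at forward time

wrapper.to(conv.weight.device)       # the fix applied in _wrap_modules above
print(wrapper.weight_mask.device)    # cuda:0
```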
```diff
@@ -307,6 +313,9 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting onnx file.
+            the tensor is placed on cpu if `device` is None
         """
         # if self.detect_modules_to_compress() and not self.mask_dict:
         #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
```
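With the new `device` parameter, callers whose model lives on a GPU can make the dummy ONNX-export input match the model's device. A hypothetical call, where `pruner` stands for any concrete `Pruner` instance and the paths and shape are placeholders:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pruner.export_model(
    model_path='pruned_weights.pth',
    mask_path='masks.pth',
    onnx_path='pruned_model.onnx',
    input_shape=(1, 3, 224, 224),  # dummy input created as torch.Tensor(*input_shape)
    device=device,                 # new in this PR; defaults to CPU when omitted
)
```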
```diff
@@ -335,12 +344,29 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
 
         self._wrap_model()
 
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters:
+        -----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
 
 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
```
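`load_model_state_dict` exists because a wrapped model's `state_dict` keys differ from the unwrapped model's (each compressed sub-module is replaced by its wrapper), so a plain `load_state_dict` on the wrapped model would not match a checkpoint saved before compression. A hypothetical round trip, assuming NNI's built-in `LevelPruner` and its documented config format:

```python
import torch
from nni.compression.torch import LevelPruner  # assumption: NNI's built-in pruner

model = torch.nn.Sequential(torch.nn.Linear(10, 10))
# checkpoint taken from the plain, unwrapped model
clean_state = {k: v.clone() for k, v in model.state_dict().items()}

config_list = [{'sparsity': 0.5, 'op_types': ['default']}]
pruner = LevelPruner(model, config_list)
pruner.compress()  # wraps the configured modules; is_wrapped becomes True

# plain model.load_state_dict(clean_state) would now fail: the wrapped model's
# keys include wrapper entries (e.g. '0.module.weight' instead of '0.weight')
pruner.load_model_state_dict(clean_state)  # unwraps, loads, re-wraps
```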
Another review thread:

> **Reviewer:** why not use the `check_weight` method?
>
> **Author:** where is this method defined, and how to use it?
>
> **Reviewer:** nni/src/sdk/pynni/nni/compression/torch/compressor.py, line 563 in d452a16