update lottery ticket pruner based on refactored compression code #1989
Conversation
QuanluZhang commented on Feb 3, 2020 (edited)
- update lottery ticket pruner based on the refactored compression code
- fix bug: allow users to specify the device of the model when exporting the mask and model; the device is used to place the dummy input tensor when exporting the ONNX model (see the sketch below)
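For context, here is a minimal sketch of what the device fix enables, assuming an `export_model`-style API with a `device` parameter; the function name and signature here are illustrative, not necessarily NNI's exact interface:

```python
import torch

def export_model(model, model_path, onnx_path=None, input_shape=None, device=None):
    """Illustrative export helper; `device` is the new caller-specified placement."""
    device = device or torch.device('cpu')
    torch.save(model.state_dict(), model_path)
    if onnx_path is not None and input_shape is not None:
        # the dummy input must live on the same device as the model,
        # otherwise torch.onnx.export fails for GPU-resident models
        dummy_input = torch.randn(input_shape, device=device)
        torch.onnx.export(model, dummy_input, onnx_path)
```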
        self._unwrap_model()
        self.bound_model.load_state_dict(model_state)
        self._wrap_model()
    else:
When will is_wrapped be false? Before calling the compress method?
Yes. It is possible that users call load_model_state before compress, so we need an approach to check whether the model has been wrapped or not. Any suggestion? (One possible approach is sketched below.)
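A minimal sketch of one way to track the wrapped state, assuming a boolean flag toggled by `_wrap_model`/`_unwrap_model`; the flag name and the surrounding class skeleton are illustrative, not NNI's actual implementation:

```python
class Compressor:
    def __init__(self, model, config_list):
        self.bound_model = model
        self.config_list = config_list
        self.is_wrapped = False  # hypothetical flag; False until compress() wraps the model

    def _wrap_model(self):
        # ... replace target modules with their wrappers ...
        self.is_wrapped = True

    def _unwrap_model(self):
        # ... restore the original modules ...
        self.is_wrapped = False

    def load_model_state(self, model_state):
        # unwrap first only if the model is currently wrapped, so this
        # method works both before and after compress() has been called
        if self.is_wrapped:
            self._unwrap_model()
            self.bound_model.load_state_dict(model_state)
            self._wrap_model()
        else:
            self.bound_model.load_state_dict(model_state)
```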
@@ -291,7 +294,10 @@ def _wrap_modules(self, layer, config):
            the configuration for generating the mask
            """
            _logger.info("compressing module %s.", layer.name)
-           return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+           wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+           assert hasattr(layer.module, 'weight')
Why not use the check_weight method?
Where is this method defined, and how do I use it?
def _check_weight(module):
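The quoted snippet is truncated here; a plausible completion, assuming `_check_weight` simply verifies that the module carries a usable weight tensor (this body is a guess based on the surrounding discussion, not a verbatim quote of NNI's source):

```python
import torch

def _check_weight(module):
    # return True only when the module has a weight attribute holding a tensor
    try:
        return isinstance(module.weight.data, torch.Tensor)
    except AttributeError:
        return False
```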
-           return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+           wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+           assert hasattr(layer.module, 'weight')
+           wrapper.to(layer.module.weight.device)
Can you explain this?
If the module is placed on a GPU, the buffers in the wrapped module are created on the CPU, so I need to move the whole wrapper to the device on which the module is placed. (A minimal repro is sketched below.)
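A minimal repro of the device mismatch, assuming a simplified wrapper that registers a mask buffer; the class below is illustrative, not NNI's actual PrunerModuleWrapper:

```python
import torch
import torch.nn as nn

class SimplifiedWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # register_buffer creates the mask on the CPU by default,
        # even when the wrapped module's parameters live on a GPU
        self.register_buffer('weight_mask', torch.ones(module.weight.shape))

    def forward(self, x):
        # weight (possibly on GPU) * weight_mask (CPU) would raise a
        # device-mismatch error unless the wrapper has been moved
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conv = nn.Conv2d(3, 8, 3).to(device)
wrapper = SimplifiedWrapper(conv)
wrapper.to(conv.weight.device)  # the fix: move the buffers to the module's device
out = wrapper(torch.randn(1, 3, 8, 8, device=device))
```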