
update lottery ticket pruner based on refactored compression code #1989

Merged

Conversation

@QuanluZhang (Contributor) commented Feb 3, 2020

  1. Update the lottery ticket pruner based on the refactored compression code.
  2. Fix a bug: allow users to specify the device of the model when exporting the mask and model; the device is used to place the dummy input tensor for exporting the ONNX model (see the sketch below).
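
A minimal sketch of the export fix described in item 2, assuming a hypothetical helper name and signature (the actual exporting API in this repository may differ); the point is that the dummy input used for ONNX export has to live on the same device as the model:

```python
import torch

def export_onnx(model, onnx_path, input_shape, device=None):
    # Hypothetical helper: place the dummy input on the user-specified device
    # so that ONNX export also works for models placed on a GPU.
    device = device if device is not None else torch.device('cpu')
    dummy_input = torch.randn(input_shape).to(device)
    torch.onnx.export(model, dummy_input, onnx_path)
```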

    self._unwrap_model()
    self.bound_model.load_state_dict(model_state)
    self._wrap_model()
else:
Contributor

When will is_wrapped be false? Before calling the compress method?

Contributor Author

Yes. It is possible that users call load_model_state before compress, so we need a way to check whether the model has been wrapped or not. Any suggestions?
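
A minimal sketch of the check being discussed, assuming a hypothetical is_wrapped bookkeeping flag maintained by _wrap_model/_unwrap_model; only the three calls already shown in the snippet above are taken from the actual code:

```python
def load_model_state(self, model_state):
    if self.is_wrapped:                     # assumed flag: True only after compress()
        self._unwrap_model()                # strip the pruner wrappers first
        self.bound_model.load_state_dict(model_state)
        self._wrap_model()                  # restore the wrappers afterwards
    else:
        # called before compress(): the model was never wrapped
        self.bound_model.load_state_dict(model_state)
```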

@@ -291,7 +294,10 @@ def _wrap_modules(self, layer, config):
            the configuration for generating the mask
        """
        _logger.info("compressing module %s.", layer.name)
-       return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+       wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+       assert hasattr(layer.module, 'weight')
Contributor

Why not use the check_weight method?

Contributor Author

Where is this method defined, and how do I use it?

Contributor

def _check_weight(module):
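
The helper body is not quoted in this thread; a plausible implementation, sketched here as an assumption, simply verifies that the module's weight is a real tensor:

```python
import torch

def _check_weight(module):
    # Sketch (assumed body): return True only if the module exposes a
    # weight attribute backed by a torch.Tensor.
    try:
        return isinstance(module.weight.data, torch.Tensor)
    except AttributeError:
        return False
```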

-       return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+       wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+       assert hasattr(layer.module, 'weight')
+       wrapper.to(layer.module.weight.device)
Contributor

Can you explain this?

Contributor Author

If the module is placed on a GPU, the buffers of the wrapper module are created on the CPU, so I need to move the whole wrapper to the device where the module is placed.
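
A minimal sketch of the behaviour being described: a stand-in wrapper (not the actual PrunerModuleWrapper) whose mask buffer is allocated with torch.ones, which defaults to the CPU regardless of where the wrapped module's parameters live, so the wrapper has to be moved explicitly:

```python
import torch
import torch.nn as nn

class MinimalWrapper(nn.Module):
    # Illustrative stand-in for PrunerModuleWrapper.
    def __init__(self, module):
        super().__init__()
        self.module = module
        # torch.ones allocates on the CPU even when module.weight is on a GPU.
        self.register_buffer('weight_mask', torch.ones(module.weight.shape))

    def forward(self, x):
        # A CPU mask multiplied with GPU weights raises a device mismatch,
        # hence the wrapper.to(...) call in the diff above.
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(x)

if torch.cuda.is_available():
    conv = nn.Conv2d(3, 8, 3).cuda()
    wrapper = MinimalWrapper(conv)      # weight_mask is still on the CPU here
    wrapper.to(conv.weight.device)      # the fix: move the buffer to the GPU too
```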

@QuanluZhang QuanluZhang closed this Feb 5, 2020
@QuanluZhang QuanluZhang reopened this Feb 5, 2020
@QuanluZhang QuanluZhang merged commit d452a16 into microsoft:dev-pruner-dataparallel Feb 5, 2020
@QuanluZhang QuanluZhang deleted the pruner-fix-bugs branch February 20, 2020 13:58