
fix model compression config validation #2033

Merged · 2 commits · Feb 14, 2020
Changes from 1 commit
29 changes: 13 additions & 16 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -149,13 +149,18 @@ def select_config(self, layer):
         ret = None
         for config in self.config_list:
             config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                config['op_types'].extend(default_layers.weighted_modules)
Contributor

better to remove 'default' after it is expanded.

Contributor Author · @Cjkkkk · Feb 12, 2020

I think it is best to just skip checking 'default' in op_types.

Contributor Author

fixed.

+            # check if condition is satisified
+            if 'op_types' in config and layer.type not in config['op_types']:
                 continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                 continue

             ret = config
-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
             return None
         return ret
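
For context, a hedged sketch of the kind of config_list this matching logic consumes; the sparsity values, layer names, and the LevelPruner usage below are illustrative assumptions, not taken from this PR:

# Illustrative config_list for a pruner in this package. With this change,
# 'default' in op_types is expanded in place to default_layers.weighted_modules,
# and an entry carrying the 'exclude' key makes select_config() return None
# for the layers it matches, so they are left uncompressed.
config_list = [
    {'sparsity': 0.8, 'op_types': ['default']},    # all weighted module types
    {'sparsity': 0.6, 'op_types': ['Conv2d']},     # a later match overrides the earlier one
    {'exclude': True, 'op_names': ['fc']},         # any value works; only the key's presence is checked
]
# Hypothetical usage (API of that era), e.g.:
#   from nni.compression.torch import LevelPruner
#   pruner = LevelPruner(model, config_list)
#   model = pruner.compress()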

@@ -188,16 +193,6 @@ def _wrap_modules(self, layer, config):
         """
         raise NotImplementedError()

-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
@@ -229,11 +224,12 @@ def __init__(self, module, module_name, module_type, config, pruner):

         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        self.registered_buffers['weight_mask'] = self.weight_mask
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
         else:
             self.register_buffer("bias_mask", None)

+        self.registered_buffers['weight_mask'] = self.weight_mask
+        self.registered_buffers['bias_mask'] = self.bias_mask
         # register user specified buffer
         for name in self.pruner.buffers:
@@ -297,7 +293,8 @@ def _wrap_modules(self, layer, config):
         """
         _logger.info("compressing module %s.", layer.name)
         wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
-        assert hasattr(layer.module, 'weight')
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device of weight
         wrapper.to(layer.module.weight.device)
         return wrapper
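
To illustrate the register_buffer / device-move pattern relied on above, here is a minimal self-contained sketch; it is a simplified stand-in for a Linear layer only, not the actual PrunerModuleWrapper:

import torch

class MaskedLinearSketch(torch.nn.Module):
    """Simplified wrapper: masks the wrapped Linear module's weight in forward()."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        # register_buffer stores the mask in state_dict and makes .to()/.cuda() move it
        self.register_buffer("weight_mask", torch.ones(module.weight.shape))

    def forward(self, x):
        # apply the mask without modifying the original parameter in place
        return torch.nn.functional.linear(x, self.module.weight * self.weight_mask, self.module.bias)

layer = torch.nn.Linear(4, 2)
wrapper = MaskedLinearSketch(layer)
# mirrors wrapper.to(layer.module.weight.device) in the diff: newly registered
# buffers end up on the same device as the wrapped weight
wrapper.to(layer.weight.device)
print(wrapper(torch.randn(1, 4)).shape)   # torch.Size([1, 2])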
