From 9d0d84d40197ac2dfa255212c31f767d7f8f9786 Mon Sep 17 00:00:00 2001
From: cjkkkk
Date: Tue, 11 Feb 2020 21:51:31 +0800
Subject: [PATCH 1/2] update config validation

---
 src/sdk/pynni/nni/compression/torch/compressor.py | 29 +++++++++----------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index 1ba3370330..abcbbf7724 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -149,13 +149,18 @@ def select_config(self, layer):
         ret = None
         for config in self.config_list:
             config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                config['op_types'].extend(default_layers.weighted_modules)
+
+            # check if condition is satisfied
+            if 'op_types' in config and layer.type not in config['op_types']:
                 continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                 continue
+
             ret = config
-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
             return None
         return ret

@@ -188,16 +193,6 @@ def _wrap_modules(self, layer, config):
         """
         raise NotImplementedError()

-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
@@ -229,11 +224,12 @@ def __init__(self, module, module_name, module_type, config, pruner):
         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        self.registered_buffers['weight_mask'] = self.weight_mask
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
         else:
             self.register_buffer("bias_mask", None)
+
+        self.registered_buffers['weight_mask'] = self.weight_mask
         self.registered_buffers['bias_mask'] = self.bias_mask
         # register user specified buffer
         for name in self.pruner.buffers:
@@ -297,7 +293,8 @@ def _wrap_modules(self, layer, config):
         """
         _logger.info("compressing module %s.", layer.name)
         wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
-        assert hasattr(layer.module, 'weight')
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device as the weight
        wrapper.to(layer.module.weight.device)
         return wrapper
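Between the two patches, it may help to see the selection semantics the first patch establishes in one self-contained piece: configs are scanned in order, later matching entries override earlier ones, and an `exclude` entry opts the layer out entirely. The sketch below is illustrative, not the actual NNI method: it takes the layer's type and name directly instead of a layer object, WEIGHTED_MODULES is a hypothetical stand-in for default_layers.weighted_modules, and it builds a new op_types list rather than extending in place (the change the second patch makes in the real code).

# Illustrative sketch of the selection semantics after PATCH 1/2.
# WEIGHTED_MODULES is a hypothetical stand-in for default_layers.weighted_modules.
WEIGHTED_MODULES = ['Conv1d', 'Conv2d', 'Conv3d', 'Linear']

def select_config(config_list, layer_type, layer_name):
    """Return the last config in config_list matching the layer, or None."""
    ret = None
    for config in config_list:
        config = config.copy()
        # expand the `default` sentinel into the concrete weighted module types
        if 'op_types' in config and 'default' in config['op_types']:
            config['op_types'] = config['op_types'] + WEIGHTED_MODULES
        # skip configs whose conditions the layer does not satisfy
        if 'op_types' in config and layer_type not in config['op_types']:
            continue
        if 'op_names' in config and layer_name not in config['op_names']:
            continue
        ret = config  # later matches override earlier ones
    if ret is None or 'exclude' in ret:
        return None
    return ret

config_list = [
    {'sparsity': 0.5, 'op_types': ['default']},
    {'exclude': True, 'op_names': ['fc']},
]
assert select_config(config_list, 'Conv2d', 'conv1')['sparsity'] == 0.5
assert select_config(config_list, 'Linear', 'fc') is None  # excluded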
From b645b9948d53fa2c587f4fa977a8a4548be48710 Mon Sep 17 00:00:00 2001
From: cjkkkk
Date: Thu, 13 Feb 2020 12:33:20 +0800
Subject: [PATCH 2/2] remove default key

---
 src/sdk/pynni/nni/compression/torch/compressor.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index abcbbf7724..14cbc194f7 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -151,7 +151,13 @@ def select_config(self, layer):
             config = config.copy()
             # expand config if key `default` is in config['op_types']
             if 'op_types' in config and 'default' in config['op_types']:
-                config['op_types'].extend(default_layers.weighted_modules)
+                expanded_op_types = []
+                for op_type in config['op_types']:
+                    if op_type == 'default':
+                        expanded_op_types.extend(default_layers.weighted_modules)
+                    else:
+                        expanded_op_types.append(op_type)
+                config['op_types'] = expanded_op_types

             # check if condition is satisfied
             if 'op_types' in config and layer.type not in config['op_types']:
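A note on why the second patch replaces the in-place extend() with a rebuilt list: config.copy() is a shallow copy, so the list mutated by extend() is the very list stored in the user's config_list, leaking the expansion back into it and re-growing it on every select_config call; rebuilding also drops the `default` sentinel itself, matching the commit subject. Below is a minimal sketch of the difference, again using a hypothetical WEIGHTED_MODULES stand-in for default_layers.weighted_modules.

# dict.copy() is shallow, so config['op_types'] in the copy is the same
# list object as in the caller's config; extend() mutates that shared list.
WEIGHTED_MODULES = ['Conv2d', 'Linear']

# In-place expansion (PATCH 1/2 behaviour): the expansion leaks back into
# the original config and would grow again on every subsequent call.
original = {'sparsity': 0.5, 'op_types': ['default']}
config = original.copy()
config['op_types'].extend(WEIGHTED_MODULES)
assert original['op_types'] == ['default', 'Conv2d', 'Linear']  # leaked

# Rebuild-and-reassign (PATCH 2/2 behaviour): the original is untouched
# and the `default` sentinel is dropped from the expanded list.
original = {'sparsity': 0.5, 'op_types': ['default']}
config = original.copy()
expanded_op_types = []
for op_type in config['op_types']:
    if op_type == 'default':
        expanded_op_types.extend(WEIGHTED_MODULES)
    else:
        expanded_op_types.append(op_type)
config['op_types'] = expanded_op_types
assert original['op_types'] == ['default']          # no side effect
assert config['op_types'] == ['Conv2d', 'Linear']   # sentinel removed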