diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py b/src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py index 5bfcc16804..a793c0dcf6 100644 --- a/src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py +++ b/src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py @@ -1,9 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import logging import torch from .infer_shape import ModuleMasks +_logger = logging.getLogger(__name__) + replace_module = { 'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask), 'Conv2d': lambda module, mask: replace_conv2d(module, mask), @@ -16,6 +19,7 @@ def no_replace(module, mask): """ No need to replace """ + _logger.debug("no need to replace") return module def replace_linear(linear, mask): @@ -37,9 +41,8 @@ def replace_linear(linear, mask): assert mask.output_mask is None assert not mask.param_masks index = mask.input_mask.mask_index[-1] - print(mask.input_mask.mask_index) in_features = index.size()[0] - print('linear: ', in_features) + _logger.debug("replace linear with new in_features: %d", in_features) new_linear = torch.nn.Linear(in_features=in_features, out_features=linear.out_features, bias=linear.bias is not None) @@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask): assert 'weight' in mask.param_masks and 'bias' in mask.param_masks index = mask.param_masks['weight'].mask_index[0] num_features = index.size()[0] - print("replace batchnorm2d: ", num_features, index) + _logger.debug("replace batchnorm2d with num_features: %d", num_features) new_norm = torch.nn.BatchNorm2d(num_features=num_features, eps=norm.eps, momentum=norm.momentum, @@ -106,6 +109,7 @@ def replace_conv2d(conv, mask): else: out_channels_index = mask.output_mask.mask_index[1] out_channels = out_channels_index.size()[0] + _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels) new_conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=conv.kernel_size, @@ -128,6 +132,5 @@ def replace_conv2d(conv, mask): assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks" new_conv.weight.data.copy_(tmp_weight_data) if conv.bias is not None: - print('final conv.bias is not None') new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data) return new_conv diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py index ae6b7ce015..67a3186462 100644 --- a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py @@ -158,7 +158,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node): """ # TODO: scope name could be empty node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)]) - #print('node_name: ', node_name) + _logger.debug("expand non-prim node, node name: %s", node_name) self.global_count += 1 op_type = node.kind() @@ -173,7 +173,6 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node): input_name = _input.debugName() if input_name in output_to_node and output_to_node[input_name] in nodes: predecessor_node = output_to_node[input_name] - #print("predecessor_node: ", predecessor_node) if predecessor_node.kind().startswith('prim::'): node_group.append(predecessor_node) node_queue.put(predecessor_node) @@ -231,7 +230,7 @@ def _build_graph(self): """ graph = self.trace_graph.graph # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here - #print(graph) + #_logger.debug(graph) # build output mapping, from output debugName to its node output_to_node = dict() # build input mapping, from input debugName to its node @@ -301,10 +300,8 @@ def _build_graph(self): m_inputs.append(_input) elif not output_to_node[_input] in nodes: m_inputs.append(_input) - print("module node_name: ", module_name) if module_name == '': - for n in nodes: - print(n) + _logger.warning("module_name is empty string") g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes) self.g_nodes.append(g_node) @@ -345,10 +342,7 @@ def _find_predecessors(self, module_name): predecessors = [] for _input in self.name_to_gnode[module_name].inputs: if not _input in self.output_to_gnode: - print(_input) - if not _input in self.output_to_gnode: - # TODO: check _input which does not have node - print("output with no gnode: ", _input) + _logger.debug("cannot find gnode with %s as its output", _input) else: g_node = self.output_to_gnode[_input] predecessors.append(g_node.name) @@ -407,15 +401,15 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=Non self.inferred_masks[module_name] = module_masks m_type = self.name_to_gnode[module_name].op_type - print("infer_module_mask: {}, module type: {}".format(module_name, m_type)) + _logger.debug("infer mask of module %s with op_type %s", module_name, m_type) if mask is not None: - #print("mask is not None") + _logger.debug("mask is not None") if not m_type in infer_from_mask: raise RuntimeError("Has not supported infering \ input/output shape from mask for module/function: `{}`".format(m_type)) input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask) if in_shape is not None: - #print("in_shape is not None") + _logger.debug("in_shape is not None") if not m_type in infer_from_inshape: raise RuntimeError("Has not supported infering \ output shape from input shape for module/function: `{}`".format(m_type)) @@ -426,23 +420,19 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=Non else: output_cmask = infer_from_inshape[m_type](module_masks, in_shape) if out_shape is not None: - #print("out_shape is not None") + _logger.debug("out_shape is not None") if not m_type in infer_from_outshape: raise RuntimeError("Has not supported infering \ input shape from output shape for module/function: `{}`".format(m_type)) input_cmask = infer_from_outshape[m_type](module_masks, out_shape) if input_cmask: - #print("input_cmask is not None") predecessors = self._find_predecessors(module_name) for _module_name in predecessors: - print("input_cmask, module_name: ", _module_name) self.infer_module_mask(_module_name, out_shape=input_cmask) if output_cmask: - #print("output_cmask is not None") successors = self._find_successors(module_name) for _module_name in successors: - print("output_cmask, module_name: ", _module_name) self.infer_module_mask(_module_name, in_shape=output_cmask) def infer_modules_masks(self): @@ -463,16 +453,19 @@ def replace_compressed_modules(self): """ for module_name in self.inferred_masks: g_node = self.name_to_gnode[module_name] - print(module_name, g_node.op_type) + _logger.debug("replace %s, in %s type, with op_type %s", + module_name, g_node.type, g_node.op_type) if g_node.type == 'module': super_module, leaf_module = get_module_by_name(self.bound_model, module_name) m_type = g_node.op_type if not m_type in replace_module: raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type)) + _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type) compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name]) setattr(super_module, module_name.split('.')[-1], compressed_module) elif g_node.type == 'func': - print("Warning: Cannot replace func...") + _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", + module_name, g_node.op_type) else: raise RuntimeError("Unsupported GNode type: {}".format(g_node.type)) @@ -482,10 +475,12 @@ def speedup_model(self): first, do mask/shape inference, second, replace modules """ - #print("start to compress") + _logger.info("start to speed up the model") + _logger.info("infer module masks...") self.infer_modules_masks() + _logger.info("replace compressed modules...") self.replace_compressed_modules() - #print("finished compressing") + _logger.info("speedup done") # resume the model mode to that before the model is speed up if self.is_training: self.bound_model.train()