diff --git a/examples/model_compress/QAT_torch_quantizer.py b/examples/model_compress/QAT_torch_quantizer.py
index 04747c9f10..f1dead41f9 100644
--- a/examples/model_compress/QAT_torch_quantizer.py
+++ b/examples/model_compress/QAT_torch_quantizer.py
@@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer):
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
-        quantizer.step()
         if batch_idx % 100 == 0:
             print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
index ce6c8bc902..7f9c3b144a 100644
--- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
+++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):
 
 
 class QAT_Quantizer(Quantizer):
-    """Quantizer using the DoReFa scheme, as defined in:
+    """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
@@ -227,20 +227,17 @@ class DoReFaQuantizer(Quantizer):
     (https://arxiv.org/abs/1606.06160)
     """
     def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - q_bits
-        """
         super().__init__(model, config_list)
 
     def quantize_weight(self, weight, config, **kwargs):
+        weight_bits = get_bits_length(config, 'weight')
         out = weight.tanh()
         out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, config['q_bits'])
+        out = self.quantize(out, weight_bits)
         out = 2 * out -1
         return out
 
     def quantize(self, input_ri, q_bits):
         scale = pow(2, q_bits)-1
         output = torch.round(input_ri*scale)/scale
-        return output
+        return output
\ No newline at end of file
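
For reference, the DoReFa weight path above squashes weights into [0, 1] with tanh, rounds them onto a uniform grid of 2^k - 1 steps, and maps the result back to [-1, 1]. A standalone sketch of that arithmetic (the function name and sample tensor are illustrative, not part of this change):

import torch

def dorefa_weight_quant(weight, weight_bits):
    # tanh-normalize into [0, 1], as in DoReFaQuantizer.quantize_weight
    out = weight.tanh()
    out = out / (2 * out.abs().max()) + 0.5
    # uniform k-bit quantization: round onto 2^k - 1 equal steps
    scale = pow(2, weight_bits) - 1
    out = torch.round(out * scale) / scale
    # rescale back into [-1, 1]
    return 2 * out - 1

print(dorefa_weight_quant(torch.randn(3, 3), weight_bits=2))
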
diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index 18a2d15fe1..e7965d837b 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """
 
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list)
+        self.quant_grad = QuantGrad
+
     def quantize_weight(self, weight, config, op, op_type, op_name):
         """
         quantize should overload this method to quantize weight.
@@ -262,7 +266,7 @@ def quantize_weight(self, weight, config, op, op_type, op_name):
         config : dict
             the configuration for weight quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_weight()")
+        raise NotImplementedError('Quantizer must overload quantize_weight()')
 
     def quantize_output(self, output, config, op, op_type, op_name):
         """
@@ -276,7 +280,7 @@ def quantize_output(self, output, config, op, op_type, op_name):
         config : dict
             the configuration for output quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_output()")
+        raise NotImplementedError('Quantizer must overload quantize_output()')
 
     def quantize_input(self, *inputs, config, op, op_type, op_name):
         """
@@ -290,7 +294,7 @@ def quantize_input(self, *inputs, config, op, op_type, op_name):
         config : dict
             the configuration for inputs quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_input()")
+        raise NotImplementedError('Quantizer must overload quantize_input()')
 
     def _instrument_layer(self, layer, config):
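
These three hooks are the extension points of the base class: a concrete quantizer overrides whichever ones its quant_types require, and the framework invokes them through self.quant_grad during the instrumented forward. A minimal sketch of such a subclass, assuming Quantizer is importable from nni.compression.torch.compressor as laid out in this tree; the symmetric 8-bit rounding and the SymmetricQuantizer name are illustrative only:

import torch
from nni.compression.torch.compressor import Quantizer

class SymmetricQuantizer(Quantizer):
    def quantize_weight(self, weight, config, **kwargs):
        # per-tensor symmetric rounding to the int8 range, then de-quantize;
        # kwargs carries the op, op_type and op_name supplied by the framework
        scale = weight.abs().max() / 127
        return torch.round(weight / scale) * scale
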
@@ -305,62 +309,93 @@ def _instrument_layer(self, layer, config):
             the configuration for quantization
         """
         assert layer._forward is None, 'Each model can only be compressed once'
-        assert "quant_types" in config, 'must provide quant_types in config'
-        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
-        assert "quant_bits" in config, 'must provide quant_bits in config'
-        assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
+        assert 'quant_types' in config, 'must provide quant_types in config'
+        assert isinstance(config['quant_types'], list), 'quant_types must be list type'
+        assert 'quant_bits' in config, 'must provide quant_bits in config'
+        assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'
 
-        if isinstance(config["quant_bits"], dict):
-            for quant_type in config["quant_types"]:
-                assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
+        if isinstance(config['quant_bits'], dict):
+            for quant_type in config['quant_types']:
+                assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
 
-        if 'weight' in config["quant_types"]:
+        if 'weight' in config['quant_types']:
             if not _check_weight(layer.module):
                 _logger.warning('Module %s does not have parameter "weight"', layer.name)
+            else:
+                # old_weight stores the original weight; weight stores the quantized weight.
+                # weight is registered as a buffer rather than a parameter because pytorch
+                # parameters must be autograd leaves; if weight stayed a leaf, old_weight could not be updated.
+                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
+                delattr(layer.module, 'weight')
+                layer.module.register_buffer('weight', layer.module.old_weight)
+
         layer._forward = layer.module.forward
 
         def new_forward(*inputs):
-            if 'input' in config["quant_types"]:
-                inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
+            if 'input' in config['quant_types']:
+                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
 
-            if 'weight' in config["quant_types"] and _check_weight(layer.module):
-                weight = layer.module.weight.data
-                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-                layer.module.weight.data = new_weight
+            if 'weight' in config['quant_types'] and _check_weight(layer.module):
+                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
+                layer.module.weight = new_weight
                 result = layer._forward(*inputs)
-                layer.module.weight.data = weight
             else:
                 result = layer._forward(*inputs)
 
-            if 'output' in config["quant_types"]:
-                result = straight_through_quantize_output.apply(result, self, config, layer)
+            if 'output' in config['quant_types']:
+                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
             return result
 
         layer.module.forward = new_forward
 
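The else branch above is what lets training proceed without the old quantizer.step() call (removed from the example at the top of this diff): weight becomes a buffer recomputed from the trainable old_weight on every forward, so the optimizer updates old_weight through the quantization op's gradient. A standalone sketch of the same re-registration trick on a bare nn.Linear, with a 0.5 scaling standing in for a real quantization function:

import torch

linear = torch.nn.Linear(4, 4)
# keep the trainable tensor under a new name
linear.register_parameter('old_weight', torch.nn.Parameter(linear.weight))
delattr(linear, 'weight')
# re-expose 'weight' as a buffer so it can hold a non-leaf, computed tensor
linear.register_buffer('weight', linear.old_weight)

linear.weight = linear.old_weight * 0.5   # stand-in for a quantization op
linear(torch.randn(1, 4)).sum().backward()
print(linear.old_weight.grad is not None)  # True: gradients reach the real parameter
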
+class QuantType:
+    """
+    Enum class for quantization type.
+    """
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
 
-class straight_through_quantize_output(torch.autograd.Function):
+class QuantGrad(torch.autograd.Function):
+    """
+    Base class for overriding the backward function of a quantization operation.
+    """
     @staticmethod
-    def forward(ctx, output, quantizer, config, layer):
-        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+    def quant_backward(tensor, grad_output, quant_type):
+        """
+        This method should be overridden by subclasses to provide a customized backward function;
+        the default implementation is the Straight-Through Estimator.
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
+        Parameters
+        ----------
+        tensor : Tensor
+            input of the quantization operation
+        grad_output : Tensor
+            gradient of the output of the quantization operation
+        quant_type : QuantType
+            the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT` or
+            `QuantType.QUANT_OUTPUT`, so you can define different behavior for each type.
 
-class straight_through_quantize_input(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inputs, quantizer, config, layer):
-        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        Returns
+        -------
+        tensor
+            gradient of the input of the quantization operation
+        """
+        return grad_output
 
     @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
+    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        tensor, quant_type = ctx.saved_tensors
+        output = cls.quant_backward(tensor, grad_output, quant_type)
+        return output, None, None, None, None
 
 def _check_weight(module):
     try:
-        return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
+        return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
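
Because backward now dispatches through cls.quant_backward, swapping out the default straight-through estimator takes only a QuantGrad subclass and one assignment to self.quant_grad. A hedged sketch (the clipped-STE rule and both class names are illustrative, not part of this PR); note that quant_type reaches quant_backward as the one-element tensor saved in forward, hence the int() conversion:

import torch
from nni.compression.torch.compressor import Quantizer, QuantGrad, QuantType

class ClippedQuantGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        # clipped STE: zero the gradient wherever the pre-quantization
        # activation left [-1, 1]; keep the plain STE for weights and inputs
        if int(quant_type) == QuantType.QUANT_OUTPUT:
            return grad_output * (tensor.abs() <= 1).type_as(grad_output)
        return grad_output

class MyQuantizer(Quantizer):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.quant_grad = ClippedQuantGrad   # route backward through the subclass
    # quantize_weight / quantize_output / quantize_input overrides omitted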