diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 39bcfc68e421..56df39fdaa30 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3713,6 +3713,7 @@ def from_pytorch(
     custom_convert_map=None,
     default_dtype="float32",
     use_parser_friendly_name=False,
+    keep_quantized_weight=False,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -3745,6 +3746,16 @@ def from_pytorch(
         so a variable name like "dense.weight" cannot be parsed correctly.
         Use this option when you want to run the AnnotateSpans pass on the imported module.
 
+    keep_quantized_weight : bool
+        Return quantized weights and bias, rather than float ones. PyTorch stores quantized
+        weights in a custom format, so we cannot directly access 8 bit weights as Numpy arrays.
+        We use a PyTorch function to unpack quantized weights into float32 arrays and
+        quantization parameters. By default, we return float32 weights and rely on the QNN
+        lowering and the Relay constant folding pass to quantize weights at compile time.
+        In BYOC use cases, however, we cannot apply the constant folding pass on a QNN graph.
+        If keep_quantized_weight is True, we quantize weights in the frontend using a function
+        that is equivalent to qnn.op.quantize(...) operating on Numpy arrays.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -3789,9 +3800,17 @@ def from_pytorch(
     # For quantized models
     quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
     if len(quantized_ops.intersection(set(op_names))) > 0:
-        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
-        qnn_torch.add_input_quant_params_to_op_inputs(graph)
-        qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
+        weight_quant_params = qnn_torch.get_weight_quant_params(
+            script_module, packed_param_map.values()
+        )
+        input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
+        qnn_torch.add_quant_params_to_outputs(
+            outputs,
+            packed_param_map,
+            weight_quant_params,
+            input_scales_for_bias,
+            keep_quantized_weight,
+        )
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         converter.update_convert_map(qnn_torch.convert_map)
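
Note (reviewer sketch, not part of the patch): a minimal end-to-end use of the new flag. The TinyConv module below is made up for illustration (the tests further down use their own ConvBn/Linear helpers); the from_pytorch call and the dtype check mirror what the patch adds.

    import torch
    from torch import nn
    from tvm import relay

    class TinyConv(nn.Module):
        """A throwaway statically quantizable model, only for this sketch."""

        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv = nn.Conv2d(3, 8, 3)
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.conv(self.quant(x)))

    model = TinyConv().eval()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    inp = torch.rand(1, 3, 32, 32)
    model(inp)  # calibration pass
    torch.quantization.convert(model, inplace=True)

    script_module = torch.jit.trace(model, inp).eval()
    mod, params = relay.frontend.from_pytorch(
        script_module, [("input", inp.shape)], keep_quantized_weight=True
    )
    # With keep_quantized_weight=True, weights come back as int8 and biases as int32.
    assert all(p.dtype in ("int8", "int32") for p in params.values())
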
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 9eafae905baf..af3c352d15ae 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -32,16 +32,12 @@ class QNNParam:
     """A placeholder for weight quantization parameters"""
 
-    def __init__(self, weight, bias, scale, zero_point, param_key):
-        param_prefix = param_key[: -len("._packed_params")]
-        self.weight_var = _expr.var(param_prefix + "_weight", shape=weight.shape)
+    def __init__(self, weight, bias, scale, zero_point):
         self.weight = weight
 
         if bias is not None:
-            self.bias_var = _expr.var(param_prefix + "_bias", shape=bias.shape)
             self.bias = bias.detach().numpy()
         else:
-            self.bias_var = None
             self.bias = None
 
         self.scale = _expr.const(scale)
@@ -55,10 +51,8 @@ class ConvPackedParam(QNNParam):
     together with weights and quantization parameters
     """
 
-    def __init__(
-        self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
-    ):
-        super().__init__(weight_np, bias, scale, zero_point, param_name)
+    def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups):
+        super().__init__(weight_np, bias, scale, zero_point)
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
@@ -81,23 +75,21 @@ def _get_quant_params(qweight):
     return weight_np, scales, 0
 
 
-def make_qnn_param(param_name, qweight, bias):
+def make_qnn_param(qweight, bias):
     weight_np, scale, zero_point = _get_quant_params(qweight)
-    return QNNParam(weight_np, bias, scale, zero_point, param_name)
+    return QNNParam(weight_np, bias, scale, zero_point)
 
 
-def make_conv_packed_param(param_name, qweight, bias, packed_params):
+def make_conv_packed_param(qweight, bias, packed_params):
     weight_np, scale, zero_point = _get_quant_params(qweight)
     stride = packed_params.stride()
     padding = packed_params.padding()
     dilation = packed_params.dilation()
     groups = packed_params.groups()
-    return ConvPackedParam(
-        weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
-    )
+    return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups)
 
 
-def get_weight_quant_params(script_module):
+def get_weight_quant_params(script_module, packed_param_names):
     """Retrive and unpack weight parameters from quantized modules"""
     import torch
 
@@ -114,6 +106,9 @@ def filter_func(named_module):
         key = name + "." + param_name
         state_dict = m.state_dict()
 
+        if key not in packed_param_names:
+            continue
+
         if len(state_dict) == 0 and not hasattr(m, param_name):
             # for v1.6 and above
             # This case seems to happen if a model is serialized
@@ -130,28 +125,87 @@ def filter_func(named_module):
         if "Conv" in m.original_name and len(state_dict) == 0:
             qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
-            quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params)
+            quant_params[key] = make_conv_packed_param(qweight, bias, packed_params)
         elif "Conv" in m.original_name:
             qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
-            quant_params[key] = make_qnn_param(key, qweight, bias)
+            quant_params[key] = make_qnn_param(qweight, bias)
         elif m.original_name == "LinearPackedParams":
             qweight, bias = torch.ops.quantized.linear_unpack(packed_params)
-            quant_params[key] = make_qnn_param(key, qweight, bias)
+            quant_params[key] = make_qnn_param(qweight, bias)
 
     return quant_params
 
 
-def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
+def quantize_numpy(weight, scale, zero_point, out_dtype_np):
+    iinfo = np.iinfo(out_dtype_np)
+    clip_min = iinfo.min
+    clip_max = iinfo.max
+    if len(scale.shape) > 0:
+        scale = np.reshape(scale, [weight.shape[0]] + [1] * (len(weight.shape) - 1))
+    transformed = zero_point + weight / scale
+    return np.clip(np.round(transformed), clip_min, clip_max).astype(out_dtype_np)
+
+
+def add_quant_params_to_outputs(
+    outputs, packed_param_map, quant_params, input_scales_for_bias, keep_quantized_weight=False
+):
     """
     Add quant params to outputs so that they can be referenced by other
     ops later. Weights are quantized here.
     """
     for node_name, packed_param_name in packed_param_map.items():
         qparam = quant_params[packed_param_name]
-        qweight = relay.qnn.op.quantize(
-            qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
-        )
-        params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]
+        weight_scale = _get_numpy(qparam.scale)
+        param_prefix = packed_param_name[: -len("._packed_params")]
+
+        if keep_quantized_weight:
+            qparam.weight_var = _expr.var(
+                param_prefix + "_weight", shape=qparam.weight.shape, dtype="int8"
+            )
+            qparam.weight = quantize_numpy(
+                qparam.weight, weight_scale, _get_numpy(qparam.zero_point), np.int8
+            )
+            qweight = qparam.weight_var
+        else:
+            qparam.weight_var = _expr.var(
+                param_prefix + "_weight", shape=qparam.weight.shape, dtype="float32"
+            )
+            qweight = relay.qnn.op.quantize(
+                qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
+            )
+
+        if qparam.bias is not None:
+            float_bias_var = _expr.var(
+                param_prefix + "_bias", shape=qparam.bias.shape, dtype="float32"
+            )
+            if node_name not in input_scales_for_bias:
+                # This case is for dynamic quantization, where the input activation scale is
+                # unknown until runtime.
+                qparam.bias_var = float_bias_var
+                qbias = qparam.bias_var
+            elif keep_quantized_weight:
+                qparam.bias_var = _expr.var(
+                    param_prefix + "_bias", shape=qparam.bias.shape, dtype="int32"
+                )
+                qparam.bias = quantize_numpy(
+                    qparam.bias, input_scales_for_bias[node_name] * weight_scale, 0, np.int32
+                )
+                qbias = qparam.bias_var
+            else:
+                qparam.bias_var = float_bias_var
+                qbias = relay.qnn.op.quantize(
+                    qparam.bias_var,
+                    _expr.const(input_scales_for_bias[node_name] * weight_scale),
+                    _expr.const(0, "int32"),
+                    out_dtype="int32",
+                    axis=0,
+                )
+        else:
+            qbias = None
+
+        quant_params[packed_param_name] = qparam
+
+        params = [qweight, qparam.scale, qparam.zero_point, qbias]
 
         if isinstance(quant_params[packed_param_name], ConvPackedParam):
             params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
 
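Note (reviewer sketch, not part of the patch): quantize_numpy above is the NumPy stand-in for qnn.op.quantize mentioned in the from_pytorch docstring, i.e. q = clip(round(w / scale + zero_point), qmin, qmax), with a 1-D scale broadcast per output channel along axis 0. A small example with made-up values:

    import numpy as np
    from tvm.relay.frontend.qnn_torch import quantize_numpy

    weight = np.array([[0.5, -0.25], [1.0, 2.0]], dtype="float32")
    per_channel_scale = np.array([0.01, 0.02], dtype="float32")  # one scale per output channel
    q = quantize_numpy(weight, per_channel_scale, 0, np.int8)
    # Row 0 is divided by 0.01 and row 1 by 0.02, then rounded and clipped to [-128, 127]:
    # q == [[50, -25], [50, 100]]
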
""" for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] - qweight = relay.qnn.op.quantize( - qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 - ) - params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var] + weight_scale = _get_numpy(qparam.scale) + param_prefix = packed_param_name[: -len("._packed_params")] + + if keep_quantized_weight: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="int8" + ) + qparam.weight = quantize_numpy( + qparam.weight, weight_scale, _get_numpy(qparam.zero_point), np.int8 + ) + qweight = qparam.weight_var + else: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="float32" + ) + qweight = relay.qnn.op.quantize( + qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 + ) + + if qparam.bias is not None: + float_bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="float32" + ) + if node_name not in input_scales_for_bias: + # This case is for dynamic quantization, where the input activation scale is + # unknown until runtime. + qparam.bias_var = float_bias_var + qbias = qparam.bias_var + elif keep_quantized_weight: + qparam.bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="int32" + ) + qparam.bias = quantize_numpy( + qparam.bias, input_scales_for_bias[node_name] * weight_scale, 0, np.int32 + ) + qbias = qparam.bias_var + else: + qparam.bias_var = float_bias_var + qbias = relay.qnn.op.quantize( + qparam.bias_var, + _expr.const(input_scales_for_bias[node_name] * weight_scale), + _expr.const(0, "int32"), + out_dtype="int32", + axis=0, + ) + else: + qbias = None + + quant_params[packed_param_name] = qparam + + params = [qweight, qparam.scale, qparam.zero_point, qbias] if isinstance(quant_params[packed_param_name], ConvPackedParam): params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups] @@ -367,6 +421,8 @@ def add_input_quant_params_to_op_inputs(graph): need_input_quant_param = set(num_quantized_inputs.keys()) need_input_quant_param.add("quantized::cat") + input_scales_for_bias = {} + for node in graph.nodes(): operator = node.kind() if operator not in need_input_quant_param: @@ -401,6 +457,12 @@ def add_input_quant_params_to_op_inputs(graph): node.addInput(scale) node.addInput(zp) + if "conv2d" in operator or "linear" in operator: + # This is required for quantizing the bias + input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value") + + return input_scales_for_bias + def add_quant_params(params, quant_params): """Add quant parameters to TVM param map""" @@ -478,10 +540,7 @@ def _do_bias_and_requantize( # Instead, the torch way requires rounding of activation at runtime if bias is not None: - qbias = relay.qnn.op.quantize( - bias, requant_input_scale, _expr.const(0, "int32"), out_dtype="int32", axis=0 - ) - requantize_input = _op.nn.bias_add(output, qbias) + requantize_input = _op.nn.bias_add(output, bias) else: requantize_input = output diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 704245040025..65e5692dc4fb 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -40,9 +40,15 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def get_tvm_runtime(script_module, input_name, ishape): +def get_tvm_runtime(script_module, input_name, 
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 704245040025..65e5692dc4fb 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -40,9 +40,15 @@ def torch_version_check():
     return version.parse(torch.__version__) > version.parse("1.4.0")
 
 
-def get_tvm_runtime(script_module, input_name, ishape):
+def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False):
     input_shapes = [(input_name, ishape)]
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    mod, params = relay.frontend.from_pytorch(
+        script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
+    )
+
+    if keep_quantized_weight:
+        for p in params.values():
+            assert p.dtype in ["int8", "int32"]
 
     with tvm.transform.PassContext(opt_level=3):
         # test on only cpu for now, torch cannot run quant models on cuda
@@ -609,3 +615,36 @@ def test_qnn_mergecomposite():
     input_name = "image"
 
     run_qnn_mergecomposite(script_module, input_name, inp.shape)
+
+
+def test_keep_quantized_weight():
+    qmodules = []
+
+    for per_channel in [False, True]:
+        qmodules += [
+            ((1, 3, 224, 224), ConvBn(), per_channel),
+            ((16, 16), Linear(), per_channel),
+        ]
+
+    for (ishape, raw_module, per_channel) in qmodules:
+        raw_module.eval()
+        inp = torch.rand(ishape)
+
+        quantize_model(raw_module, inp, per_channel=per_channel)
+        script_module = torch.jit.trace(raw_module, inp).eval()
+
+        input_name = "input"
+
+        runtime = get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False)
+        runtime.set_input(input_name, inp.numpy().copy())
+        runtime.run()
+        tvm_result = runtime.get_output(0).numpy()
+
+        runtime_int8_weight = get_tvm_runtime(
+            script_module, input_name, ishape, keep_quantized_weight=True
+        )
+        runtime_int8_weight.set_input(input_name, inp.numpy().copy())
+        runtime_int8_weight.run()
+        tvm_result_int8_weight = runtime_int8_weight.get_output(0).numpy()
+
+        tvm.testing.assert_allclose(tvm_result, tvm_result_int8_weight)
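
Note (reviewer sketch, not part of the patch): continuing the TinyConv sketch after the pytorch.py hunk above, the effect of the flag is visible directly in the returned params dict. By default the weights stay float32 and are quantized later by QNN lowering plus constant folding; with keep_quantized_weight=True the frontend hands back int8 weights and int32 biases:

    _, params_fp32 = relay.frontend.from_pytorch(script_module, [("input", inp.shape)])
    _, params_int8 = relay.frontend.from_pytorch(
        script_module, [("input", inp.shape)], keep_quantized_weight=True
    )
    print({k: v.dtype for k, v in params_fp32.items()})  # float32 weight and bias
    print({k: v.dtype for k, v in params_int8.items()})  # int8 weight, int32 bias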