diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 9eafae905baf..13c2bd17ce57 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -56,13 +56,25 @@ class ConvPackedParam(QNNParam):
     """

     def __init__(
-        self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+        self,
+        weight_np,
+        bias,
+        scale,
+        zero_point,
+        param_name,
+        stride,
+        padding,
+        dilation,
+        groups,
+        output_padding,
     ):
         super().__init__(weight_np, bias, scale, zero_point, param_name)
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
         self.groups = groups
+        # Used only for conv_transpose2d
+        self.output_padding = output_padding


 def _get_quant_params(qweight):
@@ -92,8 +104,18 @@ def make_conv_packed_param(param_name, qweight, bias, packed_params):
     padding = packed_params.padding()
     dilation = packed_params.dilation()
     groups = packed_params.groups()
+    output_padding = packed_params.output_padding()
     return ConvPackedParam(
-        weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+        weight_np,
+        bias,
+        scale,
+        zero_point,
+        param_name,
+        stride,
+        padding,
+        dilation,
+        groups,
+        output_padding,
     )


@@ -154,7 +176,13 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
         params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]

         if isinstance(quant_params[packed_param_name], ConvPackedParam):
-            params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
+            params += [
+                qparam.stride,
+                qparam.padding,
+                qparam.dilation,
+                qparam.groups,
+                qparam.output_padding,
+            ]

         outputs[node_name] = params

@@ -192,6 +220,7 @@ def _get_quant_param_for_input(input_value):
         "quantized::mul_scalar": (2, 3),
         "quantized::add_scalar": (2, 3),
         "quantized::hardswish": (1, 2),
+        "quantized::conv_transpose2d": qconv_indices,
     }

     def dfs(current_node):
@@ -362,6 +391,7 @@ def add_input_quant_params_to_op_inputs(graph):
         "quantized::relu6": 1,
         "quantized::hardswish": 1,
         "aten::hardsigmoid": 1,
+        "quantized::conv_transpose2d": 1,
     }

     need_input_quant_param = set(num_quantized_inputs.keys())
@@ -924,6 +954,65 @@ def _impl(inputs, _):
     return _impl


+def _quantized_conv_transpose2d(with_relu=False):
+    def _impl(inputs, _):
+        # Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp
+        # Supported in Torch 1.7 or newer
+        conv_params = inputs[1]
+        weight = conv_params[0]
+        weight_scale = conv_params[1]
+        weight_zero_point = conv_params[2]
+        bias = conv_params[3]
+
+        strides = conv_params[4]
+        padding = conv_params[5]
+        dilation = conv_params[6]
+        groups = conv_params[7]
+        output_padding = conv_params[8]
+
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+
+        assert len(inputs) == 6, "Input quant params not found in op inputs"
+
+        # These are manually added by add_input_quant_params_to_op_inputs above
+        # In torch, they are retrieved from QTensor data structure at runtime
+        input_scale = _expr.const(inputs[4])
+        input_zero_point = _expr.const(inputs[5])
+
+        weight_shape = list(infer_shape(weight))
+
+        # Swap I and O dims to match shape relay expects for OIHW
+        weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
+
+        kernel_size = (weight_shape[2], weight_shape[3])
+        out_channels = weight_shape[0]
+
+        conv_out = relay.qnn.op.conv2d_transpose(
+            inputs[0],
+            weight,
+            input_zero_point,
+            weight_zero_point,
+            input_scale,
+            weight_scale,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            strides=strides,
+            padding=padding,
+            groups=groups,
+            channels=out_channels,
+            output_padding=output_padding,
+            out_dtype="int32",
+            kernel_layout="OIHW",
+        )
+
+        return _do_bias_and_requantize(
+            conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
+        )
+
+    return _impl
+
+
 convert_map = {
     "aten::quantize_per_tensor": _quantize_per_tensor(),
     "quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -941,4 +1030,5 @@ def _impl(inputs, _):
     "quantized::relu6": _relu6(),
     "quantized::linear_dynamic": _linear_dynamic(),
     "quantized::hardswish": _hswish(),
+    "quantized::conv_transpose2d": _quantized_conv_transpose2d(),
 }
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 704245040025..3bd61dea644c 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -92,6 +92,20 @@ def fuse_model(self):
         fuse_modules(self.conv, indices, inplace=True)


+class ConvTranspose(nn.Module):
+    def __init__(self):
+        super().__init__()
+        layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
+        self.conv = nn.Sequential(*layers)
+        self.quant_wrap = QuantWrapper(self.conv)
+
+    def forward(self, x):
+        return self.quant_wrap(x)
+
+    def fuse_model(self):
+        pass
+
+
 class Linear(nn.Module):
     def __init__(self, with_relu=False):
         super().__init__()
@@ -270,6 +284,7 @@ def test_quantized_modules():
         ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
         ("linear" + postfix, (16, 16), Linear(), per_channel),
         ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
+        ("conv_transpose", imagenet_ishape, ConvTranspose(), False),
         ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
         ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
         ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
@@ -281,7 +296,15 @@ def test_quantized_modules():
         raw_module.eval()
         inp = torch.rand(ishape)

-        quantize_model(raw_module, inp, per_channel=per_channel)
+        # quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0.
+        if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
+            prev_engine = torch.backends.quantized.engine
+            torch.backends.quantized.engine = "qnnpack"
+            quantize_model(raw_module, inp, per_channel=per_channel)
+            torch.backends.quantized.engine = prev_engine
+        else:
+            quantize_model(raw_module, inp, per_channel=per_channel)
+
         script_module = torch.jit.trace(raw_module, inp).eval()

         with torch.no_grad():
@@ -308,6 +331,7 @@ def test_quantized_modules():
     conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
     linear 0.15987062 0.009231662 0.794921875
     linear_relu 0.14180502 0.0053220326 0.8828125
+    conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
     conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
     conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
     linear, per_channel 0.0 0.0 1.0
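Not part of the patch: the sketch below shows how the conversion path this diff enables would be exercised end to end. It assumes the patch is applied, a Torch build (1.7+) whose active quantization engine supports quantized conv_transpose2d (the "qnnpack" engine choice mirrors the test change above), and replaces the test's quantize_model helper with direct torch.quantization prepare/convert calls. The input name "input" and the (1, 3, 224, 224) shape are arbitrary choices for illustration.

# Usage sketch (illustrative, not part of the patch).
import torch
from torch import nn
from torch.quantization import QuantWrapper, get_default_qconfig, prepare, convert

from tvm import relay

# Older Torch versions only support quantized conv_transpose2d under qnnpack.
torch.backends.quantized.engine = "qnnpack"

# Wrap ConvTranspose2d so quant/dequant stubs sit at the model boundary,
# as the ConvTranspose test module above does via QuantWrapper.
model = QuantWrapper(nn.Sequential(nn.ConvTranspose2d(3, 32, 3, bias=True))).eval()
model.qconfig = get_default_qconfig("qnnpack")
prepare(model, inplace=True)

inp = torch.rand(1, 3, 224, 224)
model(inp)  # one forward pass to calibrate the observers
convert(model, inplace=True)  # the traced graph now contains quantized::conv_transpose2d

script_module = torch.jit.trace(model, inp).eval()
mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])
# _quantized_conv_transpose2d in the patch lowers the op to
# qnn.conv2d_transpose with int32 accumulation, followed by bias addition
# and requantization to the recorded output scale and zero point.

Note the design choice visible in _quantized_conv_transpose2d: PyTorch stores ConvTranspose2d weights with the input-channel dimension first, so the converter swaps the first two entries of the inferred weight shape before reading out_channels and kernel_size for the Relay call.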