Skip to content

Commit

Permalink
[Frontend][PyTorch] support for quantized conv_transpose2d op
Browse files Browse the repository at this point in the history
PyTorch uses the same underlying function to pack and
unpack the params for conv2d and conv_transpose2d ops.

This change adds support for quantized conv_transpose2d op
by reusing the ConvPackedParam and adding the
output_padding param to it.
This output_padding param will remain unused in case of conv2d.

Also added test for above with specific condition for
torch v1.7.1 and below.
  • Loading branch information
abraham-arun committed Sep 28, 2021
1 parent 4251103 commit ae0da6b
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 4 deletions.
96 changes: 93 additions & 3 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,25 @@ class ConvPackedParam(QNNParam):
"""

def __init__(
self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
self,
weight_np,
bias,
scale,
zero_point,
param_name,
stride,
padding,
dilation,
groups,
output_padding,
):
super().__init__(weight_np, bias, scale, zero_point, param_name)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
# Used only for conv_transpose2d
self.output_padding = output_padding


def _get_quant_params(qweight):
Expand Down Expand Up @@ -92,8 +104,18 @@ def make_conv_packed_param(param_name, qweight, bias, packed_params):
padding = packed_params.padding()
dilation = packed_params.dilation()
groups = packed_params.groups()
output_padding = packed_params.output_padding()
return ConvPackedParam(
weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
weight_np,
bias,
scale,
zero_point,
param_name,
stride,
padding,
dilation,
groups,
output_padding,
)


Expand Down Expand Up @@ -154,7 +176,13 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]

if isinstance(quant_params[packed_param_name], ConvPackedParam):
params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
params += [
qparam.stride,
qparam.padding,
qparam.dilation,
qparam.groups,
qparam.output_padding,
]

outputs[node_name] = params

Expand Down Expand Up @@ -192,6 +220,7 @@ def _get_quant_param_for_input(input_value):
"quantized::mul_scalar": (2, 3),
"quantized::add_scalar": (2, 3),
"quantized::hardswish": (1, 2),
"quantized::conv_transpose2d": qconv_indices,
}

def dfs(current_node):
Expand Down Expand Up @@ -362,6 +391,7 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::relu6": 1,
"quantized::hardswish": 1,
"aten::hardsigmoid": 1,
"quantized::conv_transpose2d": 1,
}

need_input_quant_param = set(num_quantized_inputs.keys())
Expand Down Expand Up @@ -924,6 +954,65 @@ def _impl(inputs, _):
return _impl


def _quantized_conv_transpose2d(with_relu=False):
def _impl(inputs, _):
# Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp
# Supported in Torch 1.7 or newer
conv_params = inputs[1]
weight = conv_params[0]
weight_scale = conv_params[1]
weight_zero_point = conv_params[2]
bias = conv_params[3]

strides = conv_params[4]
padding = conv_params[5]
dilation = conv_params[6]
groups = conv_params[7]
output_padding = conv_params[8]

output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])

assert len(inputs) == 6, "Input quant params not found in op inputs"

# These are manually added by add_input_quant_params_to_op_inputs above
# In torch, they are retrieved from QTensor data structure at runtime
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])

weight_shape = list(infer_shape(weight))

# Swap I and O dims to match shape relay expects for OIHW
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[0]

conv_out = relay.qnn.op.conv2d_transpose(
inputs[0],
weight,
input_zero_point,
weight_zero_point,
input_scale,
weight_scale,
kernel_size=kernel_size,
dilation=dilation,
strides=strides,
padding=padding,
groups=groups,
channels=out_channels,
output_padding=output_padding,
out_dtype="int32",
kernel_layout="OIHW",
)

return _do_bias_and_requantize(
conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
)

return _impl


convert_map = {
"aten::quantize_per_tensor": _quantize_per_tensor(),
"quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
Expand All @@ -941,4 +1030,5 @@ def _impl(inputs, _):
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
"quantized::hardswish": _hswish(),
"quantized::conv_transpose2d": _quantized_conv_transpose2d(),
}
26 changes: 25 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ def fuse_model(self):
fuse_modules(self.conv, indices, inplace=True)


class ConvTranspose(nn.Module):
def __init__(self):
super().__init__()
layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
self.conv = nn.Sequential(*layers)
self.quant_wrap = QuantWrapper(self.conv)

def forward(self, x):
return self.quant_wrap(x)

def fuse_model(self):
pass


class Linear(nn.Module):
def __init__(self, with_relu=False):
super().__init__()
Expand Down Expand Up @@ -270,6 +284,7 @@ def test_quantized_modules():
("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
("linear" + postfix, (16, 16), Linear(), per_channel),
("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
("conv_transpose", imagenet_ishape, ConvTranspose(), False),
("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
("hswish", imagenet_ishape, Hswish(add_stub=True), False),
("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
Expand All @@ -281,7 +296,15 @@ def test_quantized_modules():
raw_module.eval()
inp = torch.rand(ishape)

quantize_model(raw_module, inp, per_channel=per_channel)
# quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0.
if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
prev_engine = torch.backends.quantized.engine
torch.backends.quantized.engine = "qnnpack"
quantize_model(raw_module, inp, per_channel=per_channel)
torch.backends.quantized.engine = prev_engine
else:
quantize_model(raw_module, inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_module, inp).eval()

with torch.no_grad():
Expand All @@ -308,6 +331,7 @@ def test_quantized_modules():
conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
linear 0.15987062 0.009231662 0.794921875
linear_relu 0.14180502 0.0053220326 0.8828125
conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
linear, per_channel 0.0 0.0 1.0
Expand Down

0 comments on commit ae0da6b

Please sign in to comment.