[Frontend][PyTorch] support for quantized conv_transpose2d op (apache#9133)

* [Frontend][PyTorch] support for quantized conv_transpose2d op

PyTorch uses the same underlying function to pack and
unpack the params for the conv2d and conv_transpose2d ops.

This change adds support for the quantized conv_transpose2d op
by reusing ConvPackedParam and adding an output_padding param
to it; output_padding remains unused for conv2d. A usage sketch
follows below.

Also adds a test for the above, with a specific condition for
torch v1.7.1 and below.

* fix after merging main
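
For context, a minimal sketch of the path this change enables, modeled on the new ConvTranspose test case in qnn_test.py below. The module layout, the (1, 3, 224, 224) input shape, and the names (model, inp, "input") are illustrative assumptions, not part of the patch:

import torch
from torch import nn
from torch.quantization import QuantWrapper, get_default_qconfig, prepare, convert
from tvm import relay

# Eager-mode static quantization of a ConvTranspose2d wrapped in quant/dequant stubs.
model = QuantWrapper(nn.Sequential(nn.ConvTranspose2d(3, 32, 3, bias=True))).eval()
model.qconfig = get_default_qconfig("qnnpack")
torch.backends.quantized.engine = "qnnpack"  # needed before torch v1.8.0, per the test below

inp = torch.rand(1, 3, 224, 224)
prepare(model, inplace=True)
model(inp)  # calibration pass
convert(model, inplace=True)

# Trace and import; the frontend now converts quantized::conv_transpose2d.
script_module = torch.jit.trace(model, inp).eval()
mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])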
abraham-arun authored and ylc committed Jan 13, 2022
1 parent 5b59563 commit 589692d
Showing 2 changed files with 121 additions and 5 deletions.
100 changes: 96 additions & 4 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -51,12 +51,25 @@ class ConvPackedParam(QNNParam):
together with weights and quantization parameters
"""

def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups):
def __init__(
self,
weight_np,
bias,
scale,
zero_point,
stride,
padding,
dilation,
groups,
output_padding,
):
super().__init__(weight_np, bias, scale, zero_point)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
# Used only for conv_transpose2d
self.output_padding = output_padding


def _get_quant_params(qweight):
@@ -86,7 +99,18 @@ def make_conv_packed_param(qweight, bias, packed_params):
padding = packed_params.padding()
dilation = packed_params.dilation()
groups = packed_params.groups()
return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups)
output_padding = packed_params.output_padding()
return ConvPackedParam(
weight_np,
bias,
scale,
zero_point,
stride,
padding,
dilation,
groups,
output_padding,
)


def get_weight_quant_params(script_module, packed_param_names):
@@ -208,7 +232,13 @@ def add_quant_params_to_outputs(
params = [qweight, qparam.scale, qparam.zero_point, qbias]

if isinstance(quant_params[packed_param_name], ConvPackedParam):
params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
params += [
qparam.stride,
qparam.padding,
qparam.dilation,
qparam.groups,
qparam.output_padding,
]

outputs[node_name] = params

@@ -246,6 +276,7 @@ def _get_quant_param_for_input(input_value):
"quantized::mul_scalar": (2, 3),
"quantized::add_scalar": (2, 3),
"quantized::hardswish": (1, 2),
"quantized::conv_transpose2d": qconv_indices,
}

def dfs(current_node):
@@ -416,6 +447,7 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::relu6": 1,
"quantized::hardswish": 1,
"aten::hardsigmoid": 1,
"quantized::conv_transpose2d": 1,
}

need_input_quant_param = set(num_quantized_inputs.keys())
@@ -457,7 +489,7 @@ def add_input_quant_params_to_op_inputs(graph):
node.addInput(scale)
node.addInput(zp)

if "conv2d" in operator or "linear" in operator:
if "conv" in operator or "linear" in operator:
# This is required for quantizing the bias
input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")

@@ -983,6 +1015,65 @@ def _impl(inputs, _):
return _impl


def _quantized_conv_transpose2d(with_relu=False):
def _impl(inputs, _):
# Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp
# Supported in Torch 1.7 or newer
conv_params = inputs[1]
weight = conv_params[0]
weight_scale = conv_params[1]
weight_zero_point = conv_params[2]
bias = conv_params[3]

strides = conv_params[4]
padding = conv_params[5]
dilation = conv_params[6]
groups = conv_params[7]
output_padding = conv_params[8]

output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])

assert len(inputs) == 6, "Input quant params not found in op inputs"

# These are manually added by add_input_quant_params_to_op_inputs above
# In torch, they are retrieved from QTensor data structure at runtime
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])

weight_shape = list(infer_shape(weight))

# Swap I and O dims to match shape relay expects for OIHW
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[0]

conv_out = relay.qnn.op.conv2d_transpose(
inputs[0],
weight,
input_zero_point,
weight_zero_point,
input_scale,
weight_scale,
kernel_size=kernel_size,
dilation=dilation,
strides=strides,
padding=padding,
groups=groups,
channels=out_channels,
output_padding=output_padding,
out_dtype="int32",
kernel_layout="OIHW",
)

return _do_bias_and_requantize(
conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
)

return _impl


convert_map = {
"aten::quantize_per_tensor": _quantize_per_tensor(),
"quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -1000,4 +1091,5 @@ def _impl(inputs, _):
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
"quantized::hardswish": _hswish(),
"quantized::conv_transpose2d": _quantized_conv_transpose2d(),
}
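
A note on the weight-shape handling in _quantized_conv_transpose2d above: PyTorch stores ConvTranspose2d weights as (in_channels, out_channels // groups, kH, kW), so the converter swaps the first two dims of the inferred shape before reading out_channels and kernel_size. A small sketch, assuming the 3-to-32 layer used in the test:

import torch

w = torch.nn.ConvTranspose2d(3, 32, 3).weight  # stored as (in_channels, out_channels, kH, kW)
shape = list(w.shape)                           # [3, 32, 3, 3]
shape[0], shape[1] = shape[1], shape[0]         # swap I and O dims, as the converter does
out_channels = shape[0]                         # 32
kernel_size = (shape[2], shape[3])              # (3, 3)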
26 changes: 25 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
@@ -98,6 +98,20 @@ def fuse_model(self):
fuse_modules(self.conv, indices, inplace=True)


class ConvTranspose(nn.Module):
def __init__(self):
super().__init__()
layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
self.conv = nn.Sequential(*layers)
self.quant_wrap = QuantWrapper(self.conv)

def forward(self, x):
return self.quant_wrap(x)

def fuse_model(self):
pass


class Linear(nn.Module):
def __init__(self, with_relu=False):
super().__init__()
@@ -276,6 +290,7 @@ def test_quantized_modules():
("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
("linear" + postfix, (16, 16), Linear(), per_channel),
("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
("conv_transpose", imagenet_ishape, ConvTranspose(), False),
("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
("hswish", imagenet_ishape, Hswish(add_stub=True), False),
("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
@@ -287,7 +302,15 @@ def test_quantized_modules():
raw_module.eval()
inp = torch.rand(ishape)

quantize_model(raw_module, inp, per_channel=per_channel)
# quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0.
if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
prev_engine = torch.backends.quantized.engine
torch.backends.quantized.engine = "qnnpack"
quantize_model(raw_module, inp, per_channel=per_channel)
torch.backends.quantized.engine = prev_engine
else:
quantize_model(raw_module, inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_module, inp).eval()

with torch.no_grad():
@@ -314,6 +337,7 @@ def test_quantized_modules():
conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
linear 0.15987062 0.009231662 0.794921875
linear_relu 0.14180502 0.0053220326 0.8828125
conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
linear, per_channel 0.0 0.0 1.0