Add initial support for quantized transpose convolution in Relay #6899

Merged · 3 commits · Nov 26, 2020
72 changes: 59 additions & 13 deletions python/tvm/relay/frontend/tflite.py
@@ -2809,7 +2809,7 @@ def convert_transpose_conv(self, op):
         # Weights
         weights_tensor_type = weights_tensor.tensor.Type()
         # weights tensor type should be UINT8 (quantization) or FLOAT32
-        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        assert weights_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
         weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
         weight_value_ohwi = self.get_tensor_value(weights_tensor)
         # Relay kernel_layout should be OIHW
@@ -2831,19 +2831,40 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)

-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            input_zero_point = input_tensor.qnn_params["zero_point"]
+            kernel_zero_point = weights_tensor.qnn_params["zero_point"]
+            input_scale = input_tensor.qnn_params["scale"]
+            kernel_scale = weights_tensor.qnn_params["scale"]
+            out = _qnn.op.conv2d_transpose(
+                in_expr,
+                weight_expr_iohw,
+                input_zero_point,
+                kernel_zero_point,
+                input_scale,
+                kernel_scale,
+                strides=(stride_h, stride_w),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="OIHW",
+                out_dtype="int32",
+            )
+        else:
+            out = _op.nn.conv2d_transpose(
+                in_expr,
+                weight_expr_iohw,
+                strides=(stride_h, stride_w),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="OIHW",
+                out_dtype=output_tensor_type_str,
+            )

-        # if we have bias
+        # Checking if there is a fused bias
         if len(input_tensors) == 4:
             bias_tensor = input_tensors[3]
             bias_tensor_type = bias_tensor.tensor.Type()
@@ -2856,6 +2877,31 @@ def convert_transpose_conv(self, op):
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

+        if output_tensor.qnn_params:
+            # Calculate the intermediate scale and zero point of the int32 output.
+            data_scale = input_tensor.qnn_params["scale"]
+            data_scale_val = get_scalar_from_constant(data_scale)
+
+            weight_scale = weights_tensor.qnn_params["scale"]
+            # If weight scale is scalar, it is per-tensor quantization
+            if isinstance(weight_scale, float):
+                weight_scale_val = get_scalar_from_constant(weight_scale)
+            else:
+                weight_scale_val = get_tensor_from_constant(weight_scale)
+
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, "float32")
+            new_input_zero_point = relay.const(0, "int32")
+
+            out = _qnn.op.requantize(
+                out,
+                input_scale=new_input_scale,
+                input_zero_point=new_input_zero_point,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+                axis=3,
+            )
         return out

     def convert_quantize(self, op):
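The requantize call above is the key piece of the quantized path: the int32 accumulator coming out of `qnn.conv2d_transpose` carries an effective scale of `input_scale * kernel_scale` and a zero point of 0, which is exactly what the frontend passes as the requantize input parameters. A minimal numeric sketch of that arithmetic, using illustrative values not taken from the PR:

```python
# Illustrative quantization parameters (assumptions, not from the PR).
data_scale = 0.5          # scale of the quantized input tensor
weight_scale = 0.25       # per-tensor scale of the quantized weights
output_scale = 0.125      # scale of the (u)int8 output tensor
output_zero_point = 3

# The int32 accumulator's effective scale is the product of the two
# input scales; its zero point is 0 (matching new_input_zero_point).
new_input_scale = data_scale * weight_scale  # 0.125

# Requantize rescales each accumulator value to the output params:
#   q_out = round(acc * new_input_scale / output_scale) + output_zero_point
acc = 42                  # an example int32 accumulator value
q_out = round(acc * new_input_scale / output_scale) + output_zero_point
assert q_out == 45
```

Per-channel weight scales follow the same arithmetic element-wise along the channel axis, which is why the code distinguishes the scalar and tensor scale cases and passes `axis=3` for the NHWC output.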
24 changes: 24 additions & 0 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -32,6 +32,12 @@ def legalize_qnn_conv2d(attrs, inputs, types):
     return qnn_conv2d_legalize(attrs, inputs, types)


+# Registering QNN Conv2DTranspose legalization function.
+@reg.register_qnn_legalize("qnn.conv2d_transpose")
+def legalize_qnn_conv2d_transpose(attrs, inputs, types):
+    return qnn_conv2d_transpose_legalize(attrs, inputs, types)
+
+
 # Registering QNN dense legalization function.
 @reg.register_qnn_legalize("qnn.dense")
 def legalize_qnn_dense(attrs, inputs, types):
@@ -46,6 +52,24 @@ def qnn_conv2d_legalize(attrs, inputs, types):
     return None


+# Generic QNN Conv2DTranspose legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_transpose_legalize(attrs, inputs, types):
+    """Convert kernel and data to int16, subtract zero points upfront,
+    and call relay.nn.conv2d_transpose."""
+
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs
+
+    shift_data = relay.subtract(
+        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16")
+    )
+    shift_kernel = relay.subtract(
+        relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16")
+    )
+    return relay.nn.conv2d_transpose(shift_data, shift_kernel, **attrs)
+
+
 # Generic QNN Conv2D legalization function.
 @tvm.target.generic_func
 def qnn_dense_legalize(attrs, inputs, types):
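The generic legalization above rewrites the QNN op into plain Relay arithmetic: both operands are cast to int16, the zero points are subtracted upfront, and the already-centred values are fed to `nn.conv2d_transpose`. A small NumPy sketch of why the int16 widening is safe for (u)int8 operands; the values here are assumptions for illustration:

```python
import numpy as np

# Extreme uint8 data with an extreme zero point (assumed values).
data = np.array([0, 255], dtype=np.uint8)
zero_point = 255

# Subtracting directly in uint8 would wrap around; casting to int16
# first is safe because the shifted range [-255, 255] fits in int16.
shifted = data.astype(np.int16) - np.int16(zero_point)
print(shifted)  # [-255    0]

# nn.conv2d_transpose then accumulates the int16 products in a wider
# type, matching the QNN semantics sum((d - zp_d) * (w - zp_w)).
```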
112 changes: 112 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -296,6 +296,118 @@ def conv2d(
     )


+def conv2d_transpose(
+    data,
+    weight,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    output_padding=(0, 0),
+    out_dtype="",
+):
+    """This operator deconvolves quantized data with a quantized kernel. The scale
+    of the output quantized tensor is the product of the kernel_scale and
+    input_scale of the input quantized tensors. The zero point of the output
+    quantized tensor is 0. By default, the dtype of output is int32. Please also
+    refer to the Requantize operator to understand how to scale back the int32
+    output to (u)int8.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    input_zero_point : tvm.relay.Expr
+        The zero point of the data distribution.
+
+    kernel_zero_point : tvm.relay.Expr
+        The zero point of the quantized kernel distribution.
+
+    input_scale : tvm.relay.Expr
+        The scale for the input tensor. The scale for the input tensor is
+        stored purely for convenience here. See more commentary below.
+
+    kernel_scale : tvm.relay.Expr
+        The scale for the weight tensor. The scale for the weight tensor is
+        stored for access during Relay passes. This information is not needed
+        in the pass pipeline after qnn.conv2d_transpose is lowered to the
+        sequence of steps in nn.conv2d_transpose. See also input_scale in
+        Requantize.
+
+    strides : Tuple[int], optional
+        The strides of convolution.
+
+    padding : Tuple[int], optional
+        The padding of convolution.
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : str, optional
+        Layout of the output; by default, out_layout is the same as data_layout.
+
+    output_padding : Tuple[int], optional
+        Used to identify the padding within the output shape
+        (only used in training, where transpose_conv represents the gradient of
+        a convolution).
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    # convert 2-way padding to 4-way padding
+    padding = get_pad_tuple2d(padding)
+    return _make.conv2d_transpose(
+        data,
+        weight,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        output_padding,
+        out_dtype,
+    )
+
+
 def add(
     lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
 ):
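As the docstring notes, `qnn.conv2d_transpose` accumulates in int32 and leaves the scale-back to requantize. Below is a minimal usage sketch that mirrors the frontend flow above; the shapes and quantization parameters are assumptions chosen for illustration:

```python
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 8, 8, 3), dtype="uint8")
# For conv2d_transpose, kernel_layout="OIHW" takes weights ordered as
# (in_channels, out_channels, H, W), matching weight_expr_iohw above.
kernel = relay.var("kernel", shape=(3, 16, 2, 2), dtype="uint8")

acc = relay.qnn.op.conv2d_transpose(
    data,
    kernel,
    input_zero_point=relay.const(128, "int32"),
    kernel_zero_point=relay.const(127, "int32"),
    input_scale=relay.const(0.5, "float32"),
    kernel_scale=relay.const(0.25, "float32"),
    strides=(2, 2),
    channels=16,
    kernel_size=(2, 2),
    data_layout="NHWC",
    kernel_layout="OIHW",
    out_dtype="int32",
)

# Scale the int32 accumulator back to uint8, as the frontend does.
out = relay.qnn.op.requantize(
    acc,
    input_scale=relay.const(0.5 * 0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(3, "int32"),
    out_dtype="uint8",
)

mod = tvm.IRModule.from_expr(out)
print(mod)  # shows both qnn ops before legalization
```

Running `relay.qnn.transform.Legalize()` on such a module should then expand the QNN op through the legalization function registered in this PR.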