[TVM] Add importer for ONNX QLinearMatMul op
 * adds importer code

 * enables `test_qlinearmatmul_2D` unit test
cconvey committed Oct 6, 2021
1 parent 627e92e commit 55d4c20
Showing 2 changed files with 145 additions and 5 deletions.
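
For reference, the following sketch (not part of this change) shows one way the new importer path could be exercised by hand; the tensor names, shapes, dtypes, and opset version below are illustrative assumptions, not values taken from the unit test.

from onnx import TensorProto, helper
from tvm import relay

# Build a minimal ONNX model containing a single 2D QLinearMatMul node.
node = helper.make_node(
    "QLinearMatMul",
    inputs=["a", "a_scale", "a_zp", "b", "b_scale", "b_zp", "y_scale", "y_zp"],
    outputs=["y"],
)
graph = helper.make_graph(
    [node],
    "qlinearmatmul_2d_example",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.UINT8, [2, 4]),
        helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("a_zp", TensorProto.UINT8, []),
        helper.make_tensor_value_info("b", TensorProto.UINT8, [4, 3]),
        helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("b_zp", TensorProto.UINT8, []),
        helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("y_zp", TensorProto.UINT8, []),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

# The importer added in this commit converts the QLinearMatMul node to Relay QNN ops.
mod, params = relay.frontend.from_onnx(model)
print(mod)
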
149 changes: 145 additions & 4 deletions python/tvm/relay/frontend/onnx.py
@@ -211,7 +211,6 @@ def get_scalar(x, params, dtype="float32"):
        x = _op.squeeze(x, [0])
    return _op.cast(x, dtype)


class OnnxOpConverter(object):
    """A helper class for holding onnx op converters."""

@@ -3506,6 +3505,150 @@ def _impl_v10(cls, inputs, attr, params):
        return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)


class QLinearMatMul(OnnxOpConverter):
    """
    Operator converter for QLinearMatMul from the ONNX standard opset.

    Limitations:
    - Only supports 2D input tensors.
    - Not guaranteed to meet the integer-overflow behavior stipulated in the
      ONNX documentation for this operator.
    """

    @classmethod
    def _impl_v10(cls, inputs, attr, params):

        # We can't necessarily anticipate which TVM backends will require certain
        # operands to be simple Relay Constant nodes, and there's no guarantee that
        # an input's specific numerical value will be available at graph-compilation
        # time. So we make a best-effort attempt to obtain a simple-Constant form
        # here, but the code below must allow for the possibility that it wasn't
        # achieved.
        def try_convert_to_Constant(x, dtype_override=None):
            if isinstance(x, _expr.Var) and x.name_hint in params:
                # 'dtype_override' may be None; _op.const then infers the dtype
                # from the numpy value.
                return _op.const(params[x.name_hint].numpy(), dtype_override)

            rank = len(infer_shape(x))
            if rank == 0:
                x_scalar = x
            elif rank == 1:
                x_scalar = _op.squeeze(x, [0])
            else:
                assert False, "op parameter '{}' must be scalar".format(x.name_hint)

            if dtype_override is not None:
                return fold_constant(_op.cast(x_scalar, dtype_override))
            else:
                return fold_constant(x_scalar)

        # Unpack the inputs and obtain some type info...
        a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs

        a_type = infer_type(a).checked_type  # 'T1' in ONNX doc for this op
        a_scale_type = infer_type(a_scale).checked_type
        a_zp_type = infer_type(a_zp).checked_type

        b_type = infer_type(b).checked_type  # 'T2' in ONNX doc for this op
        b_scale_type = infer_type(b_scale).checked_type
        b_zp_type = infer_type(b_zp).checked_type

        y_scale_type = infer_type(y_scale).checked_type
        y_zp_type = infer_type(y_zp).checked_type  # 'T3' in ONNX doc for this op

        a_shape = infer_shape(a)
        b_shape = infer_shape(b)

        # Verify type assumptions, based on the ONNX doc for this op...
        assert a_type.dtype in ['int8', 'uint8']
        assert a_scale_type.dtype == 'float32'
        assert a_zp_type.dtype == a_type.dtype

        assert b_type.dtype in ['int8', 'uint8']
        assert b_scale_type.dtype == 'float32'
        assert b_zp_type.dtype == b_type.dtype

        assert y_scale_type.dtype == 'float32'
        assert y_zp_type.dtype in ['int8', 'uint8']

        # TODO: relax this limitation in a future version of this importer.
        a_rank = len(a_shape)
        b_rank = len(b_shape)
        assert (a_rank == 2) and (
            b_rank == 2
        ), "QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but rank(a)={}, rank(b)={}".format(
            a_rank, b_rank
        )

        # _qnn.op.dense requires the zero-point values to have dtype int32.
        a_scale_scalar = try_convert_to_Constant(a_scale)
        a_zp_scalar = try_convert_to_Constant(a_zp, 'int32')

        b_scale_scalar = try_convert_to_Constant(b_scale)
        b_zp_scalar = try_convert_to_Constant(b_zp, 'int32')

        y_scale_scalar = try_convert_to_Constant(y_scale)
        y_zp_scalar = try_convert_to_Constant(y_zp, 'int32')

        # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with
        # the '_qnn.op.dense' instance below.
        num_hidden_units = infer_shape(b)[-1]
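        # relay's qnn.dense follows the nn.dense convention: the weight argument is
        # expected in (units, input_dim) layout and the output has shape (batch, units).
        # With 'b' of shape (K, N), passing _op.transpose(b) as the weight and setting
        # num_hidden_units = b_shape[-1] = N therefore yields the (M, N) result of an
        # (M, K) x (K, N) matmul.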

        # Specify the matmul result dtype as int32, so that hopefully the matmul will use
        # a 32-bit accumulator, as seems to be required by the ONNX op's documentation.
        #
        # Details:
        # The ONNX documentation for this op is clear about acceptable overflow
        # behavior during the matmul operation:
        # - The scalar multiplication ops MAY NOT overflow.
        # - The scalar addition ops, which sum the results of the scalar multiplications,
        #   MAY overflow, but if they do, they must behave as one would expect during
        #   32-bit integer-addition overflow.
        # As of this writing, Relay's qnn.op.dense operator doesn't expose a way for us to
        # express these constraints.
        matmul_result_dtype = "int32"

        matmul_result = _qnn.op.dense(
            a, _op.transpose(b), a_zp_scalar, b_zp_scalar,
            a_scale_scalar, b_scale_scalar, num_hidden_units, matmul_result_dtype
        )

        # This information might only be found in the C++ code-comments for the
        # qnn.dense op, but the quantized tensor returned by _qnn.op.dense
        # has scale == (a_scale_scalar * b_scale_scalar) and zero_point == 0.
        #
        # 'matmul_result_zp_scalar' has dtype 'int32' to satisfy the input requirements
        # of the [de/re]quantize ops below.
        matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
        matmul_result_zp_scalar = _op.const(0, dtype="int32")
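        # Conceptually, the remaining work is to rescale the int32 matmul result into the
        # requested output quantization, i.e. roughly:
        #
        #   y = clip(round((matmul_result_scale / y_scale) * q) + y_zp, qmin, qmax)
        #
        # where q is the int32 matmul result (whose zero point is 0) and (qmin, qmax) is
        # the representable range of the output dtype.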

        # 'qnn.op.requantize' requires 'y_scale' to be a compile-time constant.
        # If it isn't, fall back to an explicit dequantize -> quantize sequence.
        if isinstance(y_scale_scalar, _expr.Constant):
            y = _qnn.op.requantize(
                matmul_result,
                matmul_result_scale_scalar,
                matmul_result_zp_scalar,
                y_scale_scalar,
                y_zp_scalar,
                axis=-1,
                rounding="TONEAREST",
                out_dtype=y_zp_type.dtype,
            )
        else:
            matmul_result_deq = _qnn.op.dequantize(
                matmul_result,
                matmul_result_scale_scalar,
                matmul_result_zp_scalar,
                axis=0,
            )

            y = _qnn.op.quantize(
                matmul_result_deq,
                y_scale_scalar,
                y_zp_scalar,
                axis=0,
                out_dtype=y_zp_type.dtype,
            )

        return y

class QLinearMul(OnnxOpConverter):
    """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""

@@ -3522,9 +3665,6 @@ def _impl_v10(cls, inputs, attr, params):

        dtype = infer_type(a).checked_type.dtype

        ## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
        ## and then requantize after
        ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
        a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
        b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
        out = _op.multiply(a, b)
@@ -4234,6 +4374,7 @@ def _get_convert_map(opset):
        "QLinearConv": QLinearConv.get_converter(opset),
        "QLinearConcat": QLinearConcat.get_converter(opset),
        "QLinearAdd": QLinearAdd.get_converter(opset),
        "QLinearMatMul": QLinearMatMul.get_converter(opset),
        "QLinearMul": QLinearMul.get_converter(opset),
        "QLinearSigmoid": QLinearSigmoid.get_converter(opset),
        "ConvInteger": ConvInteger.get_converter(opset),
1 change: 0 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -4941,7 +4941,6 @@ def verify_eyelike(indata):
    "test_mvn",
    # This test fails llvm with a lowering error:
    "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
    "test_qlinearmatmul_2D",
    "test_qlinearmatmul_3D",
    "test_range_float_type_positive_delta_expanded",
    "test_range_int32_type_negative_delta_expanded",