Skip to content

Commit

Permalink
[TFLite] Support quantised SQUARED_DIFFERENCE operator
Browse files Browse the repository at this point in the history
Add support and a test for the quantised SQUARED_DIFFERENCE operator in
the TFLite frontend.

Co-Authored-By: Shai Maor <Shai.Maor@arm.com>
  • Loading branch information
leandron and shamao01 committed May 12, 2022
1 parent 588679e commit abc68f1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
20 changes: 12 additions & 8 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ def convert_square(self, op):

return out

def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False, dequantize=False):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.AddOptions import AddOptions
Expand Down Expand Up @@ -1254,8 +1254,13 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):

# If quantized, extracts qnn params and call QNN add operator.
if not ignore_qnn_params and lhs_tensor.qnn_params:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
if not dequantize:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
else:
lhs_expr = self.dequantize(lhs_expr, lhs_tensor)
rhs_expr = self.dequantize(rhs_expr, rhs_tensor)

out = relay_op(
lhs=lhs_expr,
rhs=rhs_expr,
Expand All @@ -1269,6 +1274,9 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
else:
out = relay_op(lhs_expr, rhs_expr)

if dequantize and output_tensor.qnn_params:
out = self.quantize(out, output_tensor)

# Options (fused_activation_function)
options = None
if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
Expand Down Expand Up @@ -1370,11 +1378,7 @@ def convert_greater(self, op):
def convert_squared_difference(self, op):
"""Convert TFLite SQUARED DIFFERENCE"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized squared difference operator is not supported yet."
)
difference = self._convert_elemwise(_op.subtract, op)
difference = self._convert_elemwise(_op.subtract, op, dequantize=True)
# _convert_elemwise has guaranteed only have one output tensor
exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
out = _op.power(difference, relay.const(2, exp_type))
Expand Down
18 changes: 16 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def compare_tflite_with_tvm(
input_range=None,
mode="graph_executor",
experimental_new_converter=False,
experimental_new_quantizer=False,
fp16_quantized=False,
):
"""Generic function to generate and compare TFLite and TVM output"""
Expand All @@ -286,6 +287,7 @@ def compare_tflite_with_tvm(
# convert to tflite model
converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
converter.experimental_new_converter = experimental_new_converter
converter.experimental_new_quantizer = experimental_new_quantizer
if quantized:
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
Expand Down Expand Up @@ -2076,6 +2078,7 @@ def _test_elemwise(
quantized=False,
qnn_op=None,
same_qnn_params=False,
experimental_new_quantizer=True,
):
"""One iteration of elemwise"""

Expand Down Expand Up @@ -2135,6 +2138,7 @@ def __test_elemwise(in_data):
quantized=True,
input_range=input_range,
experimental_new_converter=same_qnn_params,
experimental_new_quantizer=experimental_new_quantizer,
)
else:
out = math_op(
Expand Down Expand Up @@ -2312,9 +2316,17 @@ def _test_not_equal(data):
# ------------------


def _test_squared_difference(data):
def _test_squared_difference(data, fused_activation_function=None, quantized=False, qnn_op=None):
"""One iteration of squared difference"""
return _test_elemwise(math_ops.squared_difference, data)
return _test_elemwise(
math_ops.squared_difference,
data,
fused_activation_function,
quantized,
qnn_op,
same_qnn_params=True,
experimental_new_quantizer=False,
)


#######################################################################
Expand Down Expand Up @@ -2378,6 +2390,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
_test_mul: (-5e3, 5e3),
_test_maximum: (-112, 111),
_test_minimum: (-128, 127),
_test_squared_difference: (0, 225e2),
}

return qnn_out_range[qnn_op]
Expand Down Expand Up @@ -2408,6 +2421,7 @@ def test_all_elemwise():
_test_forward_elemwise_quantized(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)
_test_forward_elemwise_quantized(_test_squared_difference)
_test_forward_elemwise(_test_greater_equal)
_test_forward_elemwise(_test_less)
_test_forward_elemwise(_test_less_equal)
Expand Down

0 comments on commit abc68f1

Please sign in to comment.