Skip to content

Commit

Permalink
[TFLite] Support quantized GREATER op in TFLite frontend
Browse files Browse the repository at this point in the history
Support GREATER quantization operation conversion as part of issue apache#9187
Continuation of apache#11519
  • Loading branch information
dchauhan-arm committed Sep 9, 2022
1 parent 1d32c40 commit 2c2c70f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
19 changes: 10 additions & 9 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,13 @@ def convert_square(self, op):

return out

def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
def _convert_elemwise(
self,
relay_op,
op,
ignore_qnn_params=False,
comparison_op=False,
):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.AddOptions import AddOptions
Expand All @@ -1316,7 +1322,7 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):

# TFLite format demands equal scale and zero_point tuple parameters for some operations
# to allow us to use non-quantized operation instead of quantized if ignore_qnn_params=True
if ignore_qnn_params:
if ignore_qnn_params and not comparison_op:
assert (
lhs_tensor.qnn_params
and self.has_same_qnn_params(lhs_tensor, output_tensor)
Expand Down Expand Up @@ -1431,12 +1437,7 @@ def convert_minimum(self, op):

def convert_greater(self, op):
"""Convert TFLite GREATER"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized GREATER operator is not supported yet."
)
return self._convert_elemwise(_op.greater, op)
return self._convert_elemwise(_op.greater, op, self.is_quantized(op), comparison_op=True)

def convert_squared_difference(self, op):
"""Convert TFLite SQUARED DIFFERENCE"""
Expand Down Expand Up @@ -1475,7 +1476,7 @@ def convert_less_equal(self, op):

def convert_equal(self, op):
"""Convert TFLite EQUAL"""
return self._convert_elemwise(_op.equal, op, self.is_quantized(op))
return self._convert_elemwise(_op.equal, op, self.is_quantized(op), comparison_op=True)

def convert_not_equal(self, op):
"""Convert TFLite NOT_EQUAL"""
Expand Down
21 changes: 18 additions & 3 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,7 @@ def _test_elemwise(
quantized=False,
qnn_op=None,
same_qnn_params=False,
comparison_op=False,
):
"""One iteration of elemwise"""

Expand Down Expand Up @@ -2298,7 +2299,7 @@ def __test_elemwise(in_data):
if x[0] is not None
}

if math_op is math_ops.equal:
if comparison_op:
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)

Expand All @@ -2307,13 +2308,17 @@ def __test_elemwise(in_data):
[x + ":0" for x in input_range.keys()],
[x[1] for x in zip(in_data, inq_data) if x[0] is not None],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=same_qnn_params,
)
else:
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
out = tf.quantization.fake_quant_with_min_max_args(
out, min=out_min, max=out_max, name="out"
)

# Note same_qnn_params uses experimental_new_converter as toco failed
compare_tflite_with_tvm(
[x[1] for x in zip(in_data, data) if x[0] is not None],
Expand Down Expand Up @@ -2440,9 +2445,17 @@ def _test_minimum(data, fused_activation_function=None, quantized=False, qnn_op=
# -------


def _test_greater(data):
def _test_greater(data, fused_activation_function=None, quantized=False, qnn_op=None):
"""One iteration of greater"""
return _test_elemwise(math_ops.greater, data)
return _test_elemwise(
math_ops.greater,
data,
fused_activation_function,
quantized,
qnn_op,
same_qnn_params=True,
comparison_op=True,
)


#######################################################################
Expand Down Expand Up @@ -2489,6 +2502,7 @@ def _test_equal(data, fused_activation_function=None, quantized=False, qnn_op=No
quantized,
qnn_op,
same_qnn_params=True,
comparison_op=True,
)


Expand Down Expand Up @@ -2615,6 +2629,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_minimum)
_test_forward_elemwise_quantized(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise_quantized(_test_greater)
_test_forward_elemwise(_test_squared_difference)
_test_forward_elemwise(_test_greater_equal)
_test_forward_elemwise(_test_less)
Expand Down

0 comments on commit 2c2c70f

Please sign in to comment.