From 7f687731fb6dd70816b876cc885368614ee1899d Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Tue, 31 May 2022 13:30:51 +0000 Subject: [PATCH] [TFLite] Support quantized GREATER op in TFLite frontend Support GREATER quantization operation conversion as part of issue #9187 --- python/tvm/relay/frontend/tflite.py | 7 +------ tests/python/frontend/tflite/test_forward.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 342c4e2ae553a..a2555cc4d78d0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1399,12 +1399,7 @@ def convert_minimum(self, op): def convert_greater(self, op): """Convert TFLite GREATER""" - # Check if the input tensor is quantized, call QNN op - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - "TFlite quantized GREATER operator is not supported yet." - ) - return self._convert_elemwise(_op.greater, op) + return self._convert_elemwise(_op.greater, op, self.is_quantized(op)) def convert_squared_difference(self, op): """Convert TFLite SQUARED DIFFERENCE""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 8b0244d75eda8..8cd9acb4422ae 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2332,9 +2332,15 @@ def _test_minimum(data, fused_activation_function=None, quantized=False, qnn_op= # ------- -def _test_greater(data): +def _test_greater(data, fused_activation_function=None, quantized=False, qnn_op=None): """One iteration of greater""" - return _test_elemwise(math_ops.greater, data) + return _test_elemwise( + math_ops.greater, + data, + fused_activation_function, + quantized, + qnn_op, + ) ####################################################################### @@ -2418,7 +2424,7 @@ def _test_floor_mod(data): def _test_forward_elemwise(testop): - """Elewise""" + """Elemwise""" testop( [ np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), @@ -2486,9 +2492,9 @@ def test_all_elemwise(): _test_forward_elemwise_quantized(_test_maximum) _test_forward_elemwise(_test_minimum) _test_forward_elemwise_quantized(_test_minimum) - _test_forward_elemwise(_test_greater) _test_forward_elemwise(_test_squared_difference) _test_forward_elemwise(_test_greater_equal) + _test_forward_elemwise(_test_greater) _test_forward_elemwise(_test_less) _test_forward_elemwise(_test_less_equal) _test_forward_elemwise(_test_equal)