diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 67d2c383cf4ea..528fabc1787f2 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2215,7 +2215,7 @@ def __test_elemwise(in_data): if None != x[0] } - if math_op is math_ops.equal: + if comparison_op: out = math_op(inq_data[0], inq_data[1]) out = with_fused_activation_function(out, fused_activation_function) @@ -2415,6 +2415,7 @@ def _test_equal(data, fused_activation_function=None, quantized=False, qnn_op=No quantized, qnn_op, same_qnn_params=True, + comparison_op=True, ) @@ -2480,8 +2481,8 @@ def _test_forward_elemwise(testop): ) -def _test_forward_elemwise_quantized(testop): - if testop is not _test_equal: +def _test_forward_elemwise_quantized(testop, comparison_op=False): + if not comparison_op: testop( [ np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), @@ -2511,6 +2512,7 @@ def _test_elemwise_qnn_out_range(qnn_op): _test_maximum: (-112, 111), _test_minimum: (-128, 127), _test_equal: (-150, 150), + _test_greater: (-150, 150), } return qnn_out_range[qnn_op] @@ -2546,7 +2548,7 @@ def test_all_elemwise(): _test_forward_elemwise(_test_less) _test_forward_elemwise(_test_less_equal) _test_forward_elemwise(_test_equal) - _test_forward_elemwise_quantized(_test_equal) + _test_forward_elemwise_quantized(_test_equal, True) _test_forward_elemwise(_test_not_equal) if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): _test_forward_elemwise(_test_floor_divide)