[TFLite] Support quantized GREATER op in TFLite frontend
Update elementwise quantized test for GREATER op
Change-Id: I9b20ab313f2ca984355a53aa289dd7c0f523f090
dchauhan-arm committed Jun 20, 2022
1 parent 7f68773 commit 301e3ae
Showing 1 changed file with 50 additions and 24 deletions.

tests/python/frontend/tflite/test_forward.py
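For orientation, a minimal sketch (not from the commit) of the graph shape these quantized comparison tests build: two float32 inputs pass through fake_quant nodes, which carry the quantization parameters the TFLite converter later uses, and feed a GREATER op whose output is boolean. The ±150 range mirrors the entry the commit adds to _test_elemwise_qnn_out_range below; names and shapes are illustrative.

import tensorflow as tf

with tf.Graph().as_default():
    # fake_quant records (min, max) for the converter while the tensors
    # themselves stay float32 until conversion.
    lhs = tf.compat.v1.placeholder(tf.float32, shape=(3, 6), name="in_0")
    rhs = tf.compat.v1.placeholder(tf.float32, shape=(3, 6), name="in_1")
    inq0 = tf.quantization.fake_quant_with_min_max_args(lhs, min=-150, max=150, name="inq_0")
    inq1 = tf.quantization.fake_quant_with_min_max_args(rhs, min=-150, max=150, name="inq_1")
    # GREATER yields tf.bool, so the output has no quantization parameters.
    out = tf.math.greater(inq0, inq1, name="out")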
@@ -2156,6 +2156,7 @@ def _test_elemwise(
     quantized=False,
     qnn_op=None,
     same_qnn_params=False,
+    comparison_op=False,
 ):
     """One iteration of elemwise"""
@@ -2200,22 +2201,32 @@ def __test_elemwise(in_data):
                 if None != x[0]
             }
 
-            out = math_op(inq_data[0], inq_data[1])
-            out = with_fused_activation_function(out, fused_activation_function)
-            out = tf.quantization.fake_quant_with_min_max_args(
-                out, min=out_min, max=out_max, name="out"
-            )
+            if comparison_op:
+                out = math_op(inq_data[0], inq_data[1])
+                out = with_fused_activation_function(out, fused_activation_function)
+                compare_tflite_with_tvm(
+                    [x[1] for x in zip(in_data, data) if None != x[0]],
+                    [x + ":0" for x in input_range.keys()],
+                    [x[1] for x in zip(in_data, inq_data) if None != x[0]],
+                    [out],
+                )
+            else:
+                out = math_op(inq_data[0], inq_data[1])
+                out = with_fused_activation_function(out, fused_activation_function)
+                out = tf.quantization.fake_quant_with_min_max_args(
+                    out, min=out_min, max=out_max, name="out"
+                )
 
-            # Note same_qnn_params uses experimental_new_converter as toco failed
-            compare_tflite_with_tvm(
-                [x[1] for x in zip(in_data, data) if None != x[0]],
-                [x + ":0" for x in input_range.keys()],
-                [x[1] for x in zip(in_data, inq_data) if None != x[0]],
-                [out],
-                quantized=True,
-                input_range=input_range,
-                experimental_new_converter=same_qnn_params,
-            )
+                # Note same_qnn_params uses experimental_new_converter as toco failed
+                compare_tflite_with_tvm(
+                    [x[1] for x in zip(in_data, data) if None != x[0]],
+                    [x + ":0" for x in input_range.keys()],
+                    [x[1] for x in zip(in_data, inq_data) if None != x[0]],
+                    [out],
+                    quantized=True,
+                    input_range=input_range,
+                    experimental_new_converter=same_qnn_params,
+                )
         else:
             out = math_op(
                 in_data[0]
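The branch above exists because a comparison's result is boolean: tf.quantization.fake_quant_with_min_max_args only accepts floating-point tensors, so the output-side fake_quant (and the quantized=True/input_range arguments) apply only to the arithmetic elementwise ops. A two-line illustration, assuming TF2 eager mode:

import tensorflow as tf

out = tf.math.greater(tf.constant([1.0, 3.0]), tf.constant([2.0, 2.0]))
print(out.dtype)  # tf.bool -- not a valid input to fake_quant_with_min_max_args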
@@ -2340,6 +2351,8 @@ def _test_greater(data, fused_activation_function=None, quantized=False, qnn_op=
         fused_activation_function,
         quantized,
         qnn_op,
+        same_qnn_params=True,
+        comparison_op=True,
     )
 
 
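_test_greater also opts into same_qnn_params=True. One plausible reading (my gloss, not stated in the commit): when both inputs share scale and zero point, the affine map q = round(x / scale) + zero_point preserves order, so a quantized GREATER can compare the integer tensors directly. A sketch with hypothetical uint8 parameters for the (-150, 150) range:

import numpy as np

scale = 300.0 / 255.0  # hypothetical: (max - min) / 255 for range (-150, 150)
zero_point = 128

def quantize(x):
    return np.uint8(np.clip(np.round(x / scale) + zero_point, 0, 255))

# Order is preserved (up to rounding for near-equal values).
x0, x1 = np.float32(42.0), np.float32(-7.0)
assert (x0 > x1) == (quantize(x0) > quantize(x1))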
@@ -2445,15 +2458,26 @@ def _test_forward_elemwise(testop):
     )
 
 
-def _test_forward_elemwise_quantized(testop):
-    testop(
-        [
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-        ],
-        quantized=True,
-        qnn_op=testop,
-    )
+def _test_forward_elemwise_quantized(testop, comparison_op=False):
+    if not comparison_op:
+        testop(
+            [
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+            ],
+            quantized=True,
+            qnn_op=testop,
+        )
+    else:
+        # no need for fake_quant to hold tensors in float32 until conversion
+        testop(
+            [
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
+            ],
+            quantized=True,
+            qnn_op=testop,
+        )
 
 
 def _test_elemwise_qnn_out_range(qnn_op):
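Note the dtype switch in the else branch: the comparison path feeds float32 data because the graph built by _test_elemwise starts from float placeholders and fake_quant nodes, so quantization to uint8 happens during TFLite conversion rather than in the feed data itself.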
@@ -2464,6 +2488,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
         _test_mul: (-5e3, 5e3),
         _test_maximum: (-112, 111),
         _test_minimum: (-128, 127),
+        _test_greater: (-150, 150),
     }
 
     return qnn_out_range[qnn_op]
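The new (-150, 150) entry is the fp32 range _test_elemwise looks up for GREATER and turns into fake_quant min/max; with same_qnn_params=True both inputs presumably share it (an inference from the test helper, not visible in this hunk). A standalone mimic of the lookup:

qnn_out_range = {"_test_greater": (-150, 150)}  # value copied from the dict above

out_min, out_max = qnn_out_range["_test_greater"]
inq0_min, inq0_max = out_min, out_max  # same_qnn_params=True -> shared input range
inq1_min, inq1_max = out_min, out_max
assert (inq0_min, inq1_max) == (-150, 150)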
@@ -2495,6 +2520,7 @@ def test_all_elemwise():
     _test_forward_elemwise(_test_squared_difference)
     _test_forward_elemwise(_test_greater_equal)
     _test_forward_elemwise(_test_greater)
+    _test_forward_elemwise_quantized(_test_greater, True)
     _test_forward_elemwise(_test_less)
     _test_forward_elemwise(_test_less_equal)
     _test_forward_elemwise(_test_equal)
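To exercise just this path locally, the standard pytest selector works (assuming a TVM build with the TFLite frontend available):

pytest tests/python/frontend/tflite/test_forward.py -k test_all_elemwise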
