From 51839008c35503d8a7745695b10ae2d5d092954b Mon Sep 17 00:00:00 2001 From: Ophir Frish Date: Thu, 23 Sep 2021 18:22:53 +0300 Subject: [PATCH] Support quantised RSQRT operator in TFLite The commit tests _convert_unary_elemwise function for the quantised and non quantized tensor for the RSQRT op. Other operators will be tested in future (separated )commits. --- python/tvm/relay/frontend/tflite.py | 14 ++++-- tests/python/frontend/tflite/test_forward.py | 53 +++++++++++++++----- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 93a1dba233f2..a66fc4736a98 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1116,8 +1116,16 @@ def _convert_unary_elemwise(self, relay_op, op): input_tensor = input_tensors[0] in_expr = self.get_expr(input_tensor.tensor_idx) - out = relay_op(in_expr) + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + if input_tensor.qnn_params: + in_expr = self.dequantize(in_expr, input_tensor) + out = relay_op(in_expr) + if output_tensor.qnn_params: + out = self.quantize(out, output_tensor) return out def convert_abs(self, op): @@ -1186,10 +1194,6 @@ def convert_sqrt(self, op): def convert_rsqrt(self, op): """Convert TFLite RSQRT""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - "TFlite quantized RSQRT operator is not supported yet." - ) return self._convert_unary_elemwise(_op.rsqrt, op) def convert_neg(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index c073681dcbf5..4a6f88417b9c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1868,16 +1868,6 @@ def _test_sqrt(data): return _test_unary_elemwise(math_ops.sqrt, data) -####################################################################### -# Rsqrt -# ----- - - -def _test_rsqrt(data): - """One iteration of rsqrt""" - return _test_unary_elemwise(math_ops.rsqrt, data) - - ####################################################################### # Neg # --- @@ -1910,7 +1900,7 @@ def _test_elu(data): def _test_forward_unary_elemwise(test_op): # functions that need positive input - if test_op.__name__ in {"_test_log", "_test_sqrt", "_test_rsqrt"}: + if test_op.__name__ in {"_test_log", "_test_sqrt"}: test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))) else: test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32)) @@ -1923,7 +1913,6 @@ def test_all_unary_elemwise(): _test_forward_unary_elemwise(_test_log) _test_forward_unary_elemwise(_test_sin) _test_forward_unary_elemwise(_test_sqrt) - _test_forward_unary_elemwise(_test_rsqrt) _test_forward_unary_elemwise(_test_neg) _test_forward_unary_elemwise(_test_square) # ceil and cos come with TFLite 1.14.0.post1 fbs schema @@ -3352,6 +3341,45 @@ def test_forward_tanh(): _test_tanh(np.arange(0, 256, 30, dtype=np.uint8), quantized=True) +####################################################################### +# RSQRT +# ---- + + +def _test_rsqrt(data, quantized=False): + """One iteration of RSQRT""" + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=1, max=6, name="inq_0" + ) + input_range = {"inq_0": (1, 6)} + out = math_ops.rsqrt(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out") + compare_tflite_with_tvm( + data, + "inq_0:0", + [inq_data], + [out], + quantized=True, + input_range=input_range, + experimental_new_converter=True, + ) + else: + out = math_ops.rsqrt(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) + + +def test_forward_rsqrt(): + """RSQRT""" + _test_rsqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False) + _test_rsqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False) + _test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True) + _test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) + + ####################################################################### # ReLu # ---- @@ -4561,6 +4589,7 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_l2_pool2d() test_forward_softmax() test_forward_tanh() + test_forward_rsqrt() test_forward_relu() test_forward_relu6() test_forward_leaky_relu()