From f8acb23ad3ef17e517f40a803961ff035145655a Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Mon, 4 May 2020 19:45:12 +0530 Subject: [PATCH] [TFLITE]GATHER_ND Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 26 ++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 31 ++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5a645c67cf61..cb10ce5ee924 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -86,6 +86,7 @@ def __init__(self, model, subgraph, exp_tab): 'FLOOR': self.convert_floor, 'FULLY_CONNECTED': self.convert_fully_connected, 'GATHER': self.convert_gather, + 'GATHER_ND' : self.convert_gather_nd, 'GREATER_EQUAL': self.convert_greater_equal, 'GREATER': self.convert_greater, 'HARD_SWISH': self.convert_hard_swish, @@ -1113,6 +1114,31 @@ def convert_gather(self, op): out = _op.take(data, indices, axis=axis, mode="fast") return out + def convert_gather_nd(self, op): + """Method to Convert TFLite GATHER_ND operator""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." + + data = self.get_tensor_expr(input_tensors[0]) + indices = self.get_tensor_expr(input_tensors[1]) + + indices_type = input_tensors[1].tensor.Type() + assert indices_type in (TensorType.INT32, TensorType.INT64) + + indices_dims = len(_infer_shape(indices)) + indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1))) + + out = _op.gather_nd(data, indices_t) + return out + def convert_strided_slice(self, op): """Method to Convert TFLite STRIDED_SLICE operator. NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index da89a139c113..15b762537f6d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -354,6 +354,36 @@ def test_forward_gather(): _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True) _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True) +####################################################################### +# Gather_ND +# --------- + +def _test_gather_nd(data, indices): + """ One iteration of GATHER_ND """ + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data") + indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype, + name="indices") + out = tf.gather_nd(in_data, indices_data) + + compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'], + [in_data, indices_data], [out]) + +def test_forward_gather_nd(): + """ GATHER_ND """ + _test_gather_nd( + np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'), + np.asarray([[0, 1], [1, 0]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(30), [5, 6]).astype('int32'), + np.asarray([[1, 2]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(12), [2, 3, 2]).astype('int32'), + np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32') + ) + ####################################################################### # StridedSlice # ------------ @@ -2191,6 +2221,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_slice() test_forward_topk() test_forward_gather() + test_forward_gather_nd() test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth()