Skip to content

Commit

Permalink
[TFLITE]Select/Where op support for tflite frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Apr 30, 2020
1 parent 095f565 commit af31d73
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
40 changes: 27 additions & 13 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOGISTIC': self.convert_logistic,
'MAX_POOL_2D': self.convert_max_pool2d,
'MAXIMUM': self.convert_maximum,
'MEAN': self._convert_reduce_mean,
'MEAN': self.convert_reduce_mean,
'MINIMUM': self.convert_minimum,
'MIRROR_PAD': self.convert_mirror_pad,
'MUL': self.convert_mul,
Expand All @@ -109,16 +109,17 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'REDUCE_ANY': self._convert_reduce_any,
'REDUCE_MAX': self._convert_reduce_max,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_PROD': self._convert_reduce_prod,
'REDUCE_ANY': self.convert_reduce_any,
'REDUCE_MAX': self.convert_reduce_max,
'REDUCE_MIN': self.convert_reduce_min,
'REDUCE_PROD': self.convert_reduce_prod,
'RELU':self.convert_relu,
'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SELECT': self.convert_select,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
Expand All @@ -132,14 +133,15 @@ def __init__(self, model, subgraph, exp_tab):
'SQUEEZE': self.convert_squeeze,
'STRIDED_SLICE': self.convert_strided_slice,
'SUB': self.convert_sub,
'SUM': self._convert_reduce_sum,
'SUM': self.convert_reduce_sum,
'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'TILE': self.convert_tile,
'TOPK_V2': self.convert_topk_v2,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
'WHERE': self.convert_select,
'ZEROS_LIKE': self.convert_zeros_like,
}

Expand Down Expand Up @@ -1241,7 +1243,7 @@ def convert_fill(self, op):
return out

def _convert_reduce(self, relay_op, op):
"""Generic method to Convert TFLite MEAN operators"""
"""Generic method to Convert TFLite REDUCE operators"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ReducerOptions import ReducerOptions
Expand Down Expand Up @@ -1285,22 +1287,22 @@ def _convert_reduce(self, relay_op, op):

return out

def _convert_reduce_min(self, op):
def convert_reduce_min(self, op):
return self._convert_reduce(_op.reduce.min, op)

def _convert_reduce_max(self, op):
def convert_reduce_max(self, op):
return self._convert_reduce(_op.reduce.max, op)

def _convert_reduce_mean(self, op):
def convert_reduce_mean(self, op):
return self._convert_reduce(_op.reduce.mean, op)

def _convert_reduce_prod(self, op):
def convert_reduce_prod(self, op):
return self._convert_reduce(_op.reduce.prod, op)

def _convert_reduce_sum(self, op):
def convert_reduce_sum(self, op):
return self._convert_reduce(_op.reduce.sum, op)

def _convert_reduce_any(self, op):
def convert_reduce_any(self, op):
return self._convert_reduce(_op.reduce.any, op)

def convert_fully_connected(self, op):
Expand Down Expand Up @@ -1697,6 +1699,18 @@ def convert_slice(self, op):

return out

def convert_select(self, op):
"""Convert TFLite SELECT"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be == 3"
cond = self.get_expr(input_tensors[0].tensor_idx)
x = self.get_expr(input_tensors[1].tensor_idx)
y = self.get_expr(input_tensors[2].tensor_idx)

out = _op.where(cond, x, y)

return out

def convert_transpose(self, op):
"""transpose implementation."""
input_tensors = self.get_input_tensors(op)
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,27 @@ def test_all_reduce():


#######################################################################
# Select, Where
# -------------

def test_forward_select():
with tf.Graph().as_default():
with tf.Session() as sess:
input1 = tf.placeholder(
tf.int32, shape=[1, 4, 4, 3], name='input1')
input2 = tf.placeholder(
tf.int32, shape=[1, 4, 4, 3], name='input2')
mask = input1 > input2
out = tf.where(mask, input1 + 1, input2 * 2)
in_data1 = np.random.uniform(
0, 10, size=(1, 4, 4, 3)).astype("int32")
in_data2 = np.random.uniform(
0, 10, size=(1, 4, 4, 3)).astype("int32")

compare_tflite_with_tvm([in_data1, in_data2], [
'input1:0', 'input2:0'], [input1, input2], [out])


# Squeeze
# -------

Expand Down Expand Up @@ -2014,6 +2035,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()

# NN
test_forward_convolution()
Expand Down

0 comments on commit af31d73

Please sign in to comment.