Skip to content

Commit

Permalink
[TFLITE][FRONTEND]Reduce_any op parsing support
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Feb 24, 2020
1 parent c4c61cb commit e32831f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self, model, subgraph, exp_tab):
'EQUAL': self.convert_equal,
'NOT_EQUAL': self.convert_not_equal,
'ZEROS_LIKE': self.convert_zeros_like,
'REDUCE_ANY': self._convert_reduce_any,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_MAX': self._convert_reduce_max,
'MEAN': self._convert_reduce_mean,
Expand Down Expand Up @@ -935,6 +936,9 @@ def _convert_reduce_prod(self, op):
def _convert_reduce_sum(self, op):
return self._convert_reduce(_op.reduce.sum, op)

def _convert_reduce_any(self, op):
return self._convert_reduce(_op.reduce.any, op)

def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
Expand Down
15 changes: 12 additions & 3 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,11 +1099,19 @@ def _test_reduce_sum(data, keep_dims=None):
""" One iteration of reduce_sum """
return _test_reduce(math_ops.reduce_sum, data, keep_dims)

#######################################################################
# Reduce_any
# -----------

def _test_reduce_any(data, keep_dims=None):
""" One iteration of reduce_any """
return _test_reduce(math_ops.reduce_any, data, keep_dims)


def _test_forward_reduce(testop):
def _test_forward_reduce(testop, dtype="float32"):
""" Reduce """
data0 = [np.random.rand(16, 16, 16, 16).astype("float32"), None]
data1 = [np.random.rand(16, 16, 16, 16).astype("float32"), np.array([1, 2], dtype=np.int32)]
data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
testop(data0)
testop(data0, keep_dims=False)
testop(data0, keep_dims=True)
Expand All @@ -1124,6 +1132,7 @@ def test_all_reduce():
_test_forward_reduce_quantized(_test_reduce_mean)
_test_forward_reduce(_test_reduce_prod)
_test_forward_reduce(_test_reduce_sum)
_test_forward_reduce(_test_reduce_any, dtype="bool")


#######################################################################
Expand Down

0 comments on commit e32831f

Please sign in to comment.