Skip to content

Commit

Permalink
[FRONTEND][TFLITE]Logical not op support (#5475)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored Apr 30, 2020
1 parent 90b08f5 commit 095f565
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
'LOG': self.convert_log,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_NOT': self.convert_logical_not,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MAX_POOL_2D': self.convert_max_pool2d,
Expand Down Expand Up @@ -992,6 +993,16 @@ def convert_logical_or(self, op):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)

def convert_logical_not(self, op):
"""Convert tflite LOGICAL_NOT"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

data = self.get_expr(input_tensors[0].tensor_idx)
out = _op.logical_not(data)

return out

def convert_gather(self, op):
"""Method to Convert TFLite GATHER operator"""
try:
Expand Down
12 changes: 11 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,12 @@ def _test_logical_binary(logical_bin_op, data):
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
out = logical_bin_op(in_data[0], in_data[1], name='out')
if logical_bin_op == math_ops.logical_not:
out = math_ops.logical_or(in_data[0], in_data[1], name='out1')
out = logical_bin_op(out, name='out')
else:
out = logical_bin_op(in_data[0], in_data[1], name='out')

compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])

def _test_forward_logical_and(data):
Expand All @@ -1194,13 +1199,18 @@ def _test_forward_logical_or(data):
""" One iteration of logical or """
return _test_logical_binary(math_ops.logical_or, data)

def _test_forward_logical_not(data):
""" One iteration of logical not """
return _test_logical_binary(math_ops.logical_not, data)

def test_all_logical():
data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
# boolean dtype is not supported by older versions than TFLite 1.15.0
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_logical_and(data)
_test_forward_logical_or(data)
_test_forward_logical_not(data)

#######################################################################
# Zeros like
Expand Down

0 comments on commit 095f565

Please sign in to comment.