Commit b3730e5

[FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite (#5510)

* [FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite

* [1] Test case added

* [2] Review comments handled

* [3] Prints removed
ANSHUMAN TRIPATHY authored May 7, 2020
1 parent 149965a commit b3730e5
Showing 2 changed files with 50 additions and 9 deletions.
33 changes: 24 additions & 9 deletions python/tvm/relay/frontend/tflite.py
@@ -1331,16 +1331,28 @@ def convert_fully_connected(self, op):
     input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
     weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()

-    # reshape input tensor from N H W C to N H*W*C
-    input_size_per_batch = 1
-    for s in range(1, len(input_tensor_shape)):
-        input_size_per_batch *= input_tensor_shape[s]
-    assert input_size_per_batch == weight_tensor_shape[1], \
-        "input size and weight size are mismatched"
-    target_shape = tuple((input_tensor_shape[0], input_size_per_batch))
+    # Weight should have only 2 dimensions (TFLite convention)
+    assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim"
+
+    # Input shape: [i_batch_size, ..., n_inputs]
+    # Filter shape: [n_inputs, n_units]
+    #
+    # As we transform the Fully_Connected input into Dense op inputs as below:
+    #   Dense expected input shape:  [batch_size, n_units]
+    #   Dense expected weight shape: [out_dim, n_units]
+    #   Dense output shape:          [batch_size, out_dim]
+    # the input must be reshaped to [batch_size = input_size / n_units, n_units]
+    input_size = 1
+    for _, shape in enumerate(input_tensor_shape):
+        input_size *= shape
+
+    # First get the batch size
+    batch_size = int(input_size / weight_tensor_shape[1])
+    target_shape = tuple((batch_size, weight_tensor_shape[1]))
     in_expr = self.get_expr(input_tensor_idx)
     in_expr = _op.reshape(in_expr, target_shape)

+    # TODO: change the output shape calculation based on the keep_dim option
     assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions
     op_options = op.BuiltinOptions()
     fully_connected_options = FullyConnectedOptions()
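Why this hunk matters: the old code assumed dimension 0 of the input was the batch and asserted that the remaining dimensions multiplied out to the weight's second dimension; TFLite instead derives the batch size as input_size / n_units. A minimal numpy sketch of the new derivation, using hypothetical shapes (an input the old code would have rejected with its size-mismatch assert):

import numpy as np

# Hypothetical TFLite tensors: input [1, 4, 4, 256], 2-dim filter (10, 256)
input_tensor_shape = (1, 4, 4, 256)   # [i_batch_size, ..., n_inputs]
weight_tensor_shape = (10, 256)

input_size = 1
for shape in input_tensor_shape:
    input_size *= shape               # 1 * 4 * 4 * 256 = 4096

batch_size = int(input_size / weight_tensor_shape[1])  # 4096 / 256 = 16
target_shape = (batch_size, weight_tensor_shape[1])    # (16, 256)

data = np.zeros(input_tensor_shape, dtype="float32")
print(data.reshape(target_shape).shape)  # (16, 256); Dense then yields (16, 10)

Under the old flattening, this input would have tripped the assert (4 * 4 * 256 != 256); the new code reshapes it to (16, 256) exactly as TFLite does.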
@@ -1352,8 +1364,11 @@ def convert_fully_connected(self, op):
     assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
     weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

-    weight_value = self.get_tensor_value(weight_tensor)
-    weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
+    if self.has_expr(weight_tensor.tensor_idx):
+        weight_expr = self.get_expr(weight_tensor.tensor_idx)
+    else:
+        weight_value = self.get_tensor_value(weight_tensor)
+        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
     weight_shape = _infer_shape(weight_expr)

     if input_tensor.qnn_params:
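The second hunk stops unconditionally materializing the weight as a constant: if the weight tensor already has an expression registered in the converter's table (for example, because it is produced by another op), that expression is reused. A rough sketch of the lookup-then-fallback pattern, with the hypothetical names ExprTable and weight_expr_for standing in for the converter's internals:

class ExprTable:
    """Hypothetical stand-in for the converter's expression table."""
    def __init__(self):
        self._exprs = {}

    def has_expr(self, idx):
        return idx in self._exprs

    def get_expr(self, idx):
        return self._exprs[idx]

    def new_const(self, value, dtype):
        return ("const", dtype, value)  # placeholder for a relay constant

def weight_expr_for(tab, tensor_idx, tensor_value, dtype):
    # Reuse an already-converted expression; otherwise embed the raw
    # tensor data as a new constant.
    if tab.has_expr(tensor_idx):
        return tab.get_expr(tensor_idx)
    return tab.new_const(tensor_value, dtype=dtype)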
26 changes: 26 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -419,6 +419,29 @@ def test_forward_cast():
     _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
     _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)

+#######################################################################
+# Batch Mat Mul
+# ----
+def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
+    with tf.Graph().as_default():
+        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A')
+        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B')
+        result = math_ops.matmul(A, B, adjoint_a=adjoint_a,
+                                 adjoint_b=adjoint_b, name='batchmatmul')
+
+        A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
+        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+        compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result])
+
+
+def test_forward_batch_matmul():
+    """ BATCH_MAT_MUL """
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32')
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
+    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32')
+
 #######################################################################
 # Tile
 # ----
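For reference on the BATCH_MAT_MUL cases added above: adjoint_a and adjoint_b transpose the last two axes of the corresponding operand before multiplication (for real-valued tensors, adjoint is a plain transpose). A quick numpy check of the expected shapes, independent of the test harness:

import numpy as np

A = np.random.uniform(high=5.0, size=(3, 5, 4)).astype("float32")
B = np.random.uniform(high=5.0, size=(3, 5, 4)).astype("float32")

# adjoint_a=True, adjoint_b=False: (3, 4, 5) @ (3, 5, 4) -> (3, 4, 4)
out = np.matmul(A.transpose(0, 2, 1), B)
print(out.shape)  # (3, 4, 4)

# adjoint_a=False, adjoint_b=True: (3, 5, 4) @ (3, 4, 5) -> (3, 5, 5)
out = np.matmul(A, B.transpose(0, 2, 1))
print(out.shape)  # (3, 5, 5)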
@@ -2001,6 +2024,9 @@ def test_forward_mediapipe_hand_landmark():
     # Cast
     test_forward_cast()

+    # BatchMatMul
+    test_forward_batch_matmul()
+
     # Tile
     test_forward_tile()