diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 36221b7467aa..f50131df25a7 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1656,7 +1656,6 @@ def convert_fully_connected(self, op):
         output_tensor_type = output_tensor.tensor.Type()
         output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

-        input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
         weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()

         # Weight should have only 2 dimensions(TFLite convention)
@@ -1669,14 +1668,7 @@ def convert_fully_connected(self, op):
         # Dense expected Input shape: [batch_size, n_units]
         # Dense expected Weight shape: [out_dim, n_units]
         # Dense output shape: [batch_size, out_dim]
-        # So it is evident that input shape: [batch_size = input_size / n_units, n_units]
-        input_size = 1
-        for _, shape in enumerate(input_tensor_shape):
-            input_size *= shape
-
-        # First get the batch size
-        batch_size = int(input_size / weight_tensor_shape[1])
-        target_shape = tuple((batch_size, weight_tensor_shape[1]))
+        target_shape = tuple((-1, weight_tensor_shape[1]))
         in_expr = self.get_expr(input_tensor_idx)
         in_expr = _op.reshape(in_expr, target_shape)

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 52491b2de308..a27e3d394980 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -210,11 +210,13 @@ def run_tflite_graph(tflite_model_buf, input_data):
     input_data = convert_to_list(input_data)

     interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
-    interpreter.allocate_tensors()
-
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()

+    for i in range(len(input_details)):
+        interpreter.resize_tensor_input(input_details[i]['index'], input_data[i].shape)
+    interpreter.allocate_tensors()
+
     # set input
     assert len(input_data) == len(input_details)
     for i in range(len(input_details)):
@@ -2515,6 +2517,20 @@ def test_forward_inception_v4_net():
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)

+def test_forward_inception_v4_net_batched():
+    """Test the Inception V4 TF Lite model with a batched input."""
+    # InceptionV4
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
+        "inception_v4.tflite")
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data = np.random.uniform(size=(4, 299, 299, 3)).astype('float32')
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-5, atol=1e-5)
+
 def test_forward_qnn_inception_v1_net():
     """Test the Quantized TFLite Inception model."""
     # InceptionV1
@@ -2880,6 +2896,7 @@ def test_forward_mediapipe_hand_landmark():
     test_forward_mobilenet_v3()
     test_forward_inception_v3_net()
     test_forward_inception_v4_net()
+    test_forward_inception_v4_net_batched()
     test_forward_coco_ssd_mobilenet_v1()
     test_forward_mediapipe_hand_landmark()
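
Note (not part of the patch): the change to run_tflite_graph follows the standard TFLite pattern for running a model with a batch size other than the one recorded in the model: resize the input tensors first, then call allocate_tensors(). Below is a minimal standalone sketch of that pattern, assuming TensorFlow's public tf.lite.Interpreter (the test helper's interpreter_wrapper exposes the same API), a hypothetical local "model.tflite", and a batch size of 4.

# Minimal sketch: feed a batched input to a TFLite model whose recorded
# input shape has batch size 1, mirroring the patched run_tflite_graph helper.
# "model.tflite" and the batch size of 4 are assumptions for illustration.
import numpy as np
import tensorflow as tf

with open("model.tflite", "rb") as f:
    model_buf = f.read()

interpreter = tf.lite.Interpreter(model_content=model_buf)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Replace the leading (batch) dimension, then allocate tensors afterwards.
batched_shape = (4,) + tuple(input_details[0]["shape"][1:])
interpreter.resize_tensor_input(input_details[0]["index"], batched_shape)
interpreter.allocate_tensors()

data = np.random.uniform(size=batched_shape).astype(input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], data)
interpreter.invoke()
out = interpreter.get_tensor(output_details[0]["index"])
print(out.shape)  # leading dimension is now 4

On the converter side, reshaping the fully connected input to (-1, n_units) lets Relay infer the batch dimension from the incoming tensor instead of computing it from a static shape, which is what allows the batched InceptionV4 test above to pass.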