From d7f42cf004096e4d85f804166354f7510ab21047 Mon Sep 17 00:00:00 2001 From: sunway Date: Fri, 24 Sep 2021 11:33:35 +0800 Subject: [PATCH] [Frontend][TFLite] fix #9078 --- python/tvm/relay/frontend/tflite.py | 2 +- tests/python/frontend/tflite/test_forward.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 4d607e46c97f..8a81f47f8bcc 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -775,7 +775,7 @@ def convert_softmax(self, op): assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - params = {"axis": 1} # 1 is channel + params = {"axis": -1} # -1 is channel in_expr = self.get_expr(input_tensor_idx) # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index f2941030f0ab..535459915ca3 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3266,6 +3266,7 @@ def _test_softmax(data): def test_forward_softmax(): """Softmax""" _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 2, 3))) ######################################################################