diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 4d607e46c97f..8a81f47f8bcc 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -775,7 +775,7 @@ def convert_softmax(self, op):
         assert len(output_tensors) == 1, "output tensors length should be 1"
         output_tensor = output_tensors[0]

-        params = {"axis": 1}  # 1 is channel
+        params = {"axis": -1}  # -1 is channel
         in_expr = self.get_expr(input_tensor_idx)

         # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index f2941030f0ab..535459915ca3 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -3266,6 +3266,7 @@ def _test_softmax(data):
 def test_forward_softmax():
     """Softmax"""
     _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+    _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 2, 3)))


 ######################################################################
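
Why the axis change matters, as a minimal NumPy sketch (not part of the patch; plain NumPy only, no TVM or TFLite imports assumed): TFLite computes SOFTMAX over the last dimension, so axis=1 and axis=-1 agree for the existing rank-2 test input of shape (1, 6) but diverge for the new rank-3 input of shape (1, 2, 3), which is exactly what the added test case exercises.

    import numpy as np

    def softmax(x, axis):
        # Numerically stable softmax along the given axis.
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    data = np.arange(6.0, dtype=np.float32).reshape((1, 2, 3))
    # Old frontend behaviour: normalize across the size-2 dimension (axis=1).
    print(softmax(data, axis=1))
    # Fixed behaviour, matching TFLite: normalize across the last dimension (axis=-1).
    print(softmax(data, axis=-1))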