diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 75ccec52f309..3dc055108416 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -955,10 +955,13 @@ def _convert_concat( if input_shape is None: input_shape = keras_layer.input_shape - if data_layout == "NHWC" or len(input_shape[0]) < 4: - axis = -1 - else: - axis = 1 + axis = keras_layer.axis + dims = len(input_shape[0]) + if data_layout == "NCHW": # need_transpose + if axis == -1: + axis = 1 + else: + axis = axis + 1 if axis < dims else 1 return _op.concatenate(_as_list(inexpr), axis=axis) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 842d803b174d..2584a36e32e9 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -159,6 +159,24 @@ def test_forward_merge(self, keras_mod): keras_model = keras_mod.models.Model(data, out) verify_keras_frontend(keras_model) + def test_forward_concatenate(self, keras_mod): + """test_forward_concatenate""" + data1 = keras_mod.layers.Input(shape=(1, 2, 2)) + data2 = keras_mod.layers.Input(shape=(1, 1, 2)) + merge_func = keras_mod.layers.Concatenate(axis=2) + out = merge_func([data1, data2]) + keras_model = keras_mod.models.Model([data1, data2], out) + verify_keras_frontend(keras_model, layout="NHWC") + verify_keras_frontend(keras_model, layout="NCHW") + # test default axis (e.g., -1) + data1 = keras_mod.layers.Input(shape=(1, 2, 2)) + data2 = keras_mod.layers.Input(shape=(1, 2, 3)) + merge_func = keras_mod.layers.Concatenate() + out = merge_func([data1, data2]) + keras_model = keras_mod.models.Model([data1, data2], out) + verify_keras_frontend(keras_model, layout="NHWC") + verify_keras_frontend(keras_model, layout="NCHW") + def test_forward_merge_dot(self, keras_mod): """test_forward_merge_dot""" data1 = keras_mod.layers.Input(shape=(2, 2)) @@ -793,6 +811,7 @@ def test_forward_time_distributed(self, keras_mod): if __name__ == "__main__": for k in [keras, tf_keras]: sut = TestKeras() + sut.test_forward_concatenate(keras_mod=k) sut.test_forward_merge_dot(keras_mod=k) sut.test_forward_merge(keras_mod=k) sut.test_forward_activations(keras_mod=k)