diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index a4d43cd43709..0d51487c355d 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -40,7 +40,7 @@ def _convert_activation(insym, keras_layer, _):
         return _sym.__add_scalar__(_sym.__mul_scalar__(insym, \
             scalar=alpha), scalar=beta)
     elif act_type == 'softmax':
-        return _sym.softmax(insym)
+        return _sym.softmax(insym, axis=1)
     elif act_type == 'sigmoid':
         return _sym.sigmoid(insym)
     elif act_type == 'tanh':
diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 58a3d8c12ff6..0147a3e2c654 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -59,6 +59,15 @@ def test_forward_elemwise_add():
     verify_keras_frontend(keras_model)
 
 
+def test_forward_softmax():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Activation('softmax')(data)
+    x = keras.layers.Concatenate()([x, x])
+    x = keras.layers.GlobalMaxPooling2D()(x)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
 def test_forward_softrelu():
     data = keras.layers.Input(shape=(32,32,3))
     x = keras.layers.Activation('softplus')(data)
@@ -145,6 +154,7 @@ def test_forward_resnet50():
 
 if __name__ == '__main__':
     test_forward_elemwise_add()
+    test_forward_softmax()
     test_forward_softrelu()
     test_forward_leaky_relu()
     test_forward_dense()
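
For context (not part of the patch): Keras applies softmax over the last axis of its NHWC tensors, while the NNVM Keras frontend emits graphs in NCHW layout, so the channel axis becomes 1. Without the explicit `axis=1`, softmax would be taken over the last axis of the NCHW tensor (the width) and disagree with Keras on any 4D input. A minimal NumPy sketch of the equivalence the fix restores (the `softmax` helper here is illustrative, not frontend code):

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

nhwc = np.random.randn(1, 32, 32, 3).astype("float32")  # Keras layout
nchw = nhwc.transpose(0, 3, 1, 2)                        # NNVM layout

keras_out = softmax(nhwc, axis=-1)  # what Keras computes (channels last)
nnvm_out = softmax(nchw, axis=1)    # what the converted graph must compute

# Identical once transposed back to NHWC.
np.testing.assert_allclose(keras_out, nnvm_out.transpose(0, 2, 3, 1), rtol=1e-5)
print("softmax over NHWC axis -1 == softmax over NCHW axis 1")
```

The new `test_forward_softmax` exercises exactly this case: a 4D activation followed by `Concatenate` and `GlobalMaxPooling2D`, so a softmax taken over the wrong axis would change the pooled values and fail the comparison against Keras.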