diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index ef9f63f3cd95..2871b7f73163 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -255,7 +255,8 @@ def get_expr(self, name): def set_expr(self, name, expr): assert isinstance(expr, _expr.Expr) - self.exprs[name] = expr + if name not in self.exprs: + self.exprs[name] = expr def set_padding(self, paddings): self.paddings = paddings diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index f6f2d99e2ea5..a865f08243eb 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -7,7 +7,7 @@ from .. import expr as _expr from .. import op as _op from ... import nd as _nd -from .common import ExprTable +from .common import ExprTable, new_var __all__ = ['from_keras'] @@ -661,12 +661,15 @@ def from_keras(model, shape=None): raise ValueError("Keras frontend currently supports data_format = channels_last only.") _check_unsupported_layers(model) + def _convert_input_layer(keras_layer): + input_name = keras_layer.name + input_shape = shape[input_name] if shape is not None and input_name in shape else None + etab.set_expr(input_name, new_var(input_name, shape=input_shape)) + etab = ExprTable() for keras_layer in model.layers: if isinstance(keras_layer, keras.engine.InputLayer): - input_name = keras_layer.name - input_shape = shape[input_name] if shape is not None and input_name in shape else None - etab.set_expr(input_name, _expr.var(input_name, shape=input_shape)) + _convert_input_layer(keras_layer) else: inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \ else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \ @@ -690,6 +693,7 @@ def from_keras(model, shape=None): for n_idx, t_idx, inbound_layer in zip_node: if isinstance(inbound_layer, keras.engine.InputLayer): expr_name = inbound_layer.name + _convert_input_layer(inbound_layer) else: expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx) expr = etab.get_expr(expr_name) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index baa2e4fc203f..90c07ac09042 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -106,6 +106,17 @@ def test_forward_dense(): verify_keras_frontend(keras_model) +def test_forward_sequential(): + keras_model = keras.models.Sequential([ + keras.layers.Dense(16, input_dim=32, activation='relu'), + keras.layers.Dropout(0.5), + keras.layers.Dense(8, activation='relu'), + keras.layers.Dropout(0.5), + keras.layers.Dense(1, activation='sigmoid') + ]) + verify_keras_frontend(keras_model) + + def test_forward_pool(): data = keras.layers.Input(shape=(32,32,1)) # maxpool @@ -244,6 +255,7 @@ def test_forward_mobilenet(): test_forward_merge() test_forward_activations() test_forward_dense() + test_forward_sequential() test_forward_pool() test_forward_conv() test_forward_upsample(interpolation='nearest')