Skip to content

Commit

Permalink
Allow converting keras.layers.Sequential
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes committed Mar 18, 2019
1 parent 011f0b6 commit 1b6caf4
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,12 +653,15 @@ def from_keras(model, shape=None):
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model)

def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))

etab = ExprTable()
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))
_convert_input_layer(keras_layer)
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
Expand All @@ -682,6 +685,7 @@ def from_keras(model, shape=None):
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer):
expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else:
expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
expr = etab.get_expr(expr_name)
Expand Down

0 comments on commit 1b6caf4

Please sign in to comment.