Skip to content

Commit

Permalink
Add hack for legacy keras (onnx#1486)
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft authored and zerollzeng committed May 16, 2021
1 parent a7c99ad commit 99b8a8a
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_

# let tensorflow do the checking if model is a valid model
function = _saving_utils.trace_model_call(model, input_signature)
concrete_func = function.get_concrete_function()
try:
concrete_func = function.get_concrete_function()
except TypeError as e:
# Legacy keras models don't accept the training arg tf provides so we hack around it
if "got an unexpected keyword argument 'training'" not in str(e):
raise e
model_call = model.call
def wrap_call(*args, training=False, **kwargs):
return model_call(*args, **kwargs)
model.call = wrap_call
function = _saving_utils.trace_model_call(model, input_signature)
concrete_func = function.get_concrete_function()
# Put it back
model.call = model_call

# These inputs will be removed during freezing (includes resources, etc.)
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
Expand Down

0 comments on commit 99b8a8a

Please sign in to comment.