diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 8c2f73b0a..4d7333c50 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -316,7 +316,20 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ # let tensorflow do the checking if model is a valid model function = _saving_utils.trace_model_call(model, input_signature) - concrete_func = function.get_concrete_function() + try: + concrete_func = function.get_concrete_function() + except TypeError as e: + # Legacy keras models don't accept the training arg tf provides so we hack around it + if "got an unexpected keyword argument 'training'" not in str(e): + raise e + model_call = model.call + def wrap_call(*args, training=False, **kwargs): + return model_call(*args, **kwargs) + model.call = wrap_call + function = _saving_utils.trace_model_call(model, input_signature) + concrete_func = function.get_concrete_function() + # Put it back + model.call = model_call # These inputs will be removed during freezing (includes resources, etc.) graph_captures = concrete_func.graph._captures # pylint: disable=protected-access