
Commit 1c60588

Fix usage of custom_ops, custom_op_handlers, and custom_rewriter args (#1708)

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
TomWildenhain-Microsoft authored Sep 13, 2021
1 parent 5e48449 commit 1c60588
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions tf2onnx/convert.py
@@ -71,7 +71,8 @@ def get_args():
     parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
     parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
                         action="store_true")
-    parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
+    parser.add_argument("--custom-ops", help="Comma-separated map of custom ops to domains in format OpName:domain. "
+                                             "Domain 'ai.onnx.converters.tensorflow' is used by default.")
     parser.add_argument("--extra_opset", default=None,
                         help="extra opset with format like domain:version, e.g. com.microsoft:1")
     parser.add_argument("--load_op_libraries",
@@ -137,13 +138,19 @@ def default_custom_op_handler(ctx, node, name, args):


 def _convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
-                    output_frozen_graph=None, **kwargs):
+                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
     """Common processing for conversion."""

     model_proto = None
     external_tensor_storage = None
     const_node_values = None

+    if custom_ops is not None:
+        if custom_op_handlers is None:
+            custom_op_handlers = {}
+        custom_op_handlers.update(
+            {op: (make_default_custom_op_handler(domain), []) for op, domain in custom_ops.items()})
+
     with tf.Graph().as_default() as tf_graph:
         if large_model:
             const_node_values = compress_graph_def(frozen_graph)
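To make the new block above concrete: a caller-supplied custom_ops map ({op_name: domain}) is folded into custom_op_handlers, and each op gets a default handler bound to its domain. A self-contained sketch of that expansion, using a stand-in for make_default_custom_op_handler (the real factory is defined earlier in tf2onnx/convert.py; the op and domain names here are made up):

def make_default_custom_op_handler(domain):
    # Stand-in for the factory in tf2onnx/convert.py; the real default handler
    # keeps the node as-is and assigns it to the requested ONNX domain.
    def handler(ctx, node, name, args):
        node.domain = domain
        return node
    return handler

custom_ops = {"CustomRelu": "ai.onnx.contrib"}
custom_op_handlers = {}
custom_op_handlers.update(
    {op: (make_default_custom_op_handler(domain), []) for op, domain in custom_ops.items()})
# custom_op_handlers now maps "CustomRelu" to (handler bound to "ai.onnx.contrib", []).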
@@ -152,7 +159,8 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
             utils.save_protobuf(output_frozen_graph, frozen_graph)
         if not kwargs.get("tflite_path") and not kwargs.get("tfjs_path"):
             tf.import_graph_def(frozen_graph, name='')
-        g = process_tf_graph(tf_graph, const_node_values=const_node_values, **kwargs)
+        g = process_tf_graph(tf_graph, const_node_values=const_node_values,
+                             custom_op_handlers=custom_op_handlers, **kwargs)
         if constants.ENV_TF2ONNX_CATCH_ERRORS in os.environ:
             catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
         else:
@@ -180,7 +188,7 @@ def main():
     extra_opset = args.extra_opset or []
     tflite_path = None
     tfjs_path = None
-    custom_ops = {}
+    custom_op_handlers = {}
     initialized_tables = None
     tensors_to_rename = {}
     if args.custom_ops:
@@ -192,7 +200,7 @@ def main():
                 # default custom ops for tensorflow-onnx are in the "tf" namespace
                 using_tf_opset = True
                 domain = constants.TENSORFLOW_OPSET.domain
-            custom_ops[op] = (make_default_custom_op_handler(domain), [])
+            custom_op_handlers[op] = (make_default_custom_op_handler(domain), [])
         if using_tf_opset:
             extra_opset.append(constants.TENSORFLOW_OPSET)

@@ -259,7 +267,7 @@ def main():
         continue_on_error=args.continue_on_error,
         target=args.target,
         opset=args.opset,
-        custom_op_handlers=custom_ops,
+        custom_op_handlers=custom_op_handlers,
         extra_opset=extra_opset,
         shape_override=args.shape_override,
         input_names=inputs,
@@ -371,7 +379,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu
         continue_on_error=True,
         target=target,
         opset=opset,
-        custom_op_handlers=custom_ops,
+        custom_ops=custom_ops,
+        custom_op_handlers=custom_op_handlers,
+        custom_rewriter=custom_rewriter,
         extra_opset=extra_opset,
         shape_override=shape_override,
         input_names=input_names,
@@ -475,7 +485,9 @@ def wrap_call(*args, training=False, **kwargs):
         continue_on_error=True,
         target=target,
         opset=opset,
-        custom_op_handlers=custom_ops,
+        custom_ops=custom_ops,
+        custom_op_handlers=custom_op_handlers,
+        custom_rewriter=custom_rewriter,
         extra_opset=extra_opset,
         shape_override=shape_override,
         input_names=input_names,
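Because from_keras now forwards custom_op_handlers (and custom_rewriter) unchanged, a caller can register an explicit handler instead of relying on the default one. A hedged sketch; the Keras model, op name, handler body, and opset choice are illustrative, not part of this commit:

import tensorflow as tf
import tf2onnx

def my_op_handler(ctx, node, name, args):
    # Illustrative handler matching the (ctx, node, name, args) signature above:
    # keep the node and move it to a custom ONNX domain.
    node.domain = "ai.onnx.contrib"
    return node

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
spec = (tf.TensorSpec((None, 8), tf.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=spec,
    custom_op_handlers={"MyCustomOp": (my_op_handler, [])},  # (handler, extra args) tuple, as in this diff
    opset=13,
)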
@@ -537,7 +549,9 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
         continue_on_error=True,
         target=target,
         opset=opset,
-        custom_op_handlers=custom_ops,
+        custom_ops=custom_ops,
+        custom_op_handlers=custom_op_handlers,
+        custom_rewriter=custom_rewriter,
         extra_opset=extra_opset,
         shape_override=shape_override,
         input_names=input_names,
@@ -599,7 +613,9 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op
         continue_on_error=True,
         target=target,
         opset=opset,
-        custom_op_handlers=custom_ops,
+        custom_ops=custom_ops,
+        custom_op_handlers=custom_op_handlers,
+        custom_rewriter=custom_rewriter,
         extra_opset=extra_opset,
         shape_override=shape_override,
         input_names=input_names,
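With these forwarding fixes in place, the public entry points (from_keras, from_function, from_graph_def) pass custom_ops, custom_op_handlers, and custom_rewriter through to the converter instead of mislabeling custom_ops as handlers. A hedged usage sketch against from_graph_def; the tiny graph, the op name, and opset=13 are illustrative choices, not part of this commit:

import tensorflow as tf
import tf2onnx

# Build a trivial GraphDef to convert (purely illustrative).
with tf.Graph().as_default() as graph:
    inp = tf.compat.v1.placeholder(tf.float32, [None, 4], name="input")
    out = tf.identity(inp, name="output")

model_proto, _ = tf2onnx.convert.from_graph_def(
    graph.as_graph_def(),
    input_names=["input:0"],
    output_names=["output:0"],
    custom_ops={"MyCustomOp": "ai.onnx.contrib"},  # op -> domain; gets a default handler
    opset=13,
)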
