Perf gain in tf_utils.py, more efficient error messages, faster comparisons #1054

Merged 1 commit on Aug 12, 2020
23 changes: 9 additions & 14 deletions tf2onnx/tf_utils.py
@@ -13,7 +13,6 @@
 from distutils.version import LooseVersion

 import numpy as np
-import six
 import tensorflow as tf

 from tensorflow.core.framework import types_pb2, tensor_pb2
@@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor):
     """Get data from tensor."""
     make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
     np_data = tensor_util.MakeNdarray(tensor)
-    make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
+    make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)
     return np_data


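The two make_sure changes in this file are the "more efficient error messages" from the PR title: str.format builds the message eagerly, including a full repr of the ndarray, on every call, whereas passing the template and its arguments separately lets make_sure interpolate only when the check fails. A minimal sketch, assuming make_sure behaves the way its call sites imply:

import numpy as np

def make_sure(bool_val, error_msg, *args):
    # Sketch of tf2onnx's make_sure as implied by its call sites (assumed
    # signature): the %-interpolation runs only on the failing branch.
    if not bool_val:
        raise ValueError("make_sure failure: " + error_msg % args)

np_data = np.zeros((3, 3))
# Old style: builds the ndarray's repr on every call, even when the check passes.
make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
# New style: no formatting work unless the ValueError is actually raised.
make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)

Since these checks pass on virtually every call, the saved formatting work adds up when converting a large graph.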
@@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True):
     when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
     when as_list=True, return 1, type is <class 'int'>.
     """
-    make_sure(is_tf_const_op(op), "{} isn't a const op".format(op.name))
+    make_sure(is_tf_const_op(op), "%r isn't a const op", op.name)
     value = get_tf_tensor_data(op.get_attr("value"))
     if as_list:
         value = value.tolist()
@@ -119,9 +118,6 @@ def map_tf_dtype(dtype):

 def get_tf_node_attr(node, name):
     """Parser TF node attribute."""
-    if six.PY2:
-        # For python2, TF get_attr does not accept unicode
-        name = str(name)
     return node.get_attr(name)


@@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override):
     """

     # ignore the following attributes
-    ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
+    ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
                     "TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
                     "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
                     "Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
                     "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
                     "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
                     "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
-                    "Toutput_types"]
+                    "Toutput_types"}

     node_list = g.get_operations()
     functions = {}
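Rewriting ignored_attr as a set literal matters because the container is probed once per attribute of every node in the graph: list membership scans elements in order, while set membership is a single hash lookup. A rough timeit comparison, using an abbreviated stand-in for the real attribute list:

import timeit

# Abbreviated stand-in for the real ignored_attr contents.
candidates_list = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu",
                   "Tpaddings", "Toutput_types"]
candidates_set = set(candidates_list)

# Late hits and misses are the expensive case for the list version;
# the set version is a single hash lookup either way.
print(timeit.timeit(lambda: "Toutput_types" in candidates_list))
print(timeit.timeit(lambda: "Toutput_types" in candidates_set))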
@@ -176,12 +172,11 @@
             attr_cnt[a] += 1
             if a == "dtype":
                 attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
-            elif a in ["T"]:
+            elif a == "T":
                 dtype = get_tf_node_attr(node, a)
-                if dtype:
-                    if not isinstance(dtype, list):
-                        dtypes[node.name] = map_tf_dtype(dtype)
-            elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
+                if dtype and not isinstance(dtype, list):
+                    dtypes[node.name] = map_tf_dtype(dtype)
+            elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}:
                 # Tidx is used by Range
                 # out_idx is used by ListDiff
                 attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
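Two micro-optimizations land in this hunk. The one-element membership test a in ["T"] becomes the plain equality a == "T", and the nested if dtype: / if not isinstance(...) pair collapses into a single short-circuiting condition. The first change can be inspected with dis; note that CPython already folds the constant list into a tuple, so the win comes from skipping the membership machinery altogether (an illustrative check, not part of the PR):

import dis

dis.dis('a in ["T"]')  # CPython folds ["T"] to a tuple constant, then runs a membership test
dis.dis('a == "T"')    # a single string equality, no container protocol at all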
@@ -192,7 +187,7 @@
             elif a == "output_shapes":
                 # we should not need it since we pull the shapes above already
                 pass
-            elif a in ["body", "cond", "then_branch", "else_branch"]:
+            elif a in {"body", "cond", "then_branch", "else_branch"}:
                 input_shapes = [inp.get_shape() for inp in node.inputs]
                 nattr = get_tf_node_attr(node, a)
                 attr[a] = nattr.name
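The set literals introduced for these branch tests are free to build at runtime: when a set display of constants appears on the right-hand side of in, CPython's compiler folds it into a frozenset constant, so each test is one hash probe rather than a scan. Again an illustrative check, not part of the PR:

import dis

# The set display compiles to a LOAD_CONST of a frozenset, so each
# membership test is a single hash lookup instead of a linear scan.
dis.dis('a in {"body", "cond", "then_branch", "else_branch"}')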