Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed tf_utils pass1 to mostly use attribute type not name to determine how to convert. #1196

Merged
merged 1 commit into from
Nov 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
"Toutput_types", "dense_shapes", "Tdense", "Tidx", "Tsegmentids", "Tshift", "Tnumsegments",
"Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
# onnx.helper.make_node fails,
# TODO: it should be added back.
Expand Down Expand Up @@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
op_cnt[node.type] += 1
for a in node.node_def.attr:
attr_cnt[a] += 1
if a == "dtype":
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
value = get_tf_node_attr(node, a)
if a in ignored_attr:
pass
elif a == "T":
dtype = get_tf_node_attr(node, a)
if dtype and not isinstance(dtype, list):
dtypes[node.name] = map_tf_dtype(dtype)
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type",
"Tsegmentids"}:
# Tidx is used by Range
# out_idx is used by ListDiff
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
elif a == "sparse_types":
attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)]
if value and not isinstance(value, list):
dtypes[node.name] = map_tf_dtype(value)
elif a == "shape":
shape = get_tf_shape_attr(node)
if shape is not None:
attr[a] = shape
elif a == "output_shapes":
# we should not need it since we pull the shapes above already
pass
elif a in {"body", "cond", "then_branch", "else_branch", "f"}:
input_shapes = [inp.get_shape() for inp in node.inputs]
nattr = get_tf_node_attr(node, a)
attr[a] = nattr.name
functions[nattr.name] = input_shapes
elif a == "value":
tensor = get_tf_node_attr(node, a)
elif a == "DstT":
attr["to"] = map_tf_dtype(value)
elif isinstance(value, tensor_pb2.TensorProto):
if const_node_values and node.name in const_node_values:
tensor.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
value.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name))
attr[a] = onnx_tensor
elif a == "DstT":
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
elif a == "SrcT":
continue
elif a in ignored_attr:
continue
elif isinstance(value, tf.DType):
attr[a] = map_tf_dtype(value)
elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
attr[a] = [map_tf_dtype(v) for v in value]
else:
attr[a] = get_tf_node_attr(node, a)

Expand Down