Skip to content

Commit

Permalink
Changed tf_utils pass1 to mostly use attribute type not name to deter…
Browse files Browse the repository at this point in the history
…mine how to convert.

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Nov 23, 2020
1 parent e3c0102 commit 90125fe
Showing 1 changed file with 15 additions and 26 deletions.
41 changes: 15 additions & 26 deletions tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
"Toutput_types", "dense_shapes", "Tdense", "Tidx", "Tsegmentids", "Tshift", "Tnumsegments",
"Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
# onnx.helper.make_node fails,
# TODO: it should be added back.
Expand Down Expand Up @@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
op_cnt[node.type] += 1
for a in node.node_def.attr:
attr_cnt[a] += 1
if a == "dtype":
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
value = get_tf_node_attr(node, a)
if a in ignored_attr:
pass
elif a == "T":
dtype = get_tf_node_attr(node, a)
if dtype and not isinstance(dtype, list):
dtypes[node.name] = map_tf_dtype(dtype)
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type",
"Tsegmentids"}:
# Tidx is used by Range
# out_idx is used by ListDiff
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
elif a == "sparse_types":
attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)]
if value and not isinstance(value, list):
dtypes[node.name] = map_tf_dtype(value)
elif a == "shape":
shape = get_tf_shape_attr(node)
if shape is not None:
attr[a] = shape
elif a == "output_shapes":
# we should not need it since we pull the shapes above already
pass
elif a in {"body", "cond", "then_branch", "else_branch", "f"}:
input_shapes = [inp.get_shape() for inp in node.inputs]
nattr = get_tf_node_attr(node, a)
attr[a] = nattr.name
functions[nattr.name] = input_shapes
elif a == "value":
tensor = get_tf_node_attr(node, a)
elif a == "DstT":
attr["to"] = map_tf_dtype(value)
elif isinstance(value, tensor_pb2.TensorProto):
if const_node_values and node.name in const_node_values:
tensor.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
value.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name))
attr[a] = onnx_tensor
elif a == "DstT":
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
elif a == "SrcT":
continue
elif a in ignored_attr:
continue
elif isinstance(value, tf.DType):
attr[a] = map_tf_dtype(value)
elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
attr[a] = [map_tf_dtype(v) for v in value]
else:
attr[a] = get_tf_node_attr(node, a)

Expand Down

0 comments on commit 90125fe

Please sign in to comment.