diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py index 362e3551b..de51b66d3 100644 --- a/tf2onnx/tf_utils.py +++ b/tf2onnx/tf_utils.py @@ -13,7 +13,6 @@ from distutils.version import LooseVersion import numpy as np -import six import tensorflow as tf from tensorflow.core.framework import types_pb2, tensor_pb2 @@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor): """Get data from tensor.""" make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto") np_data = tensor_util.MakeNdarray(tensor) - make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data)) + make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data) return np_data @@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True): when as_list=False, return np.array(1), type is when as_list=True, return 1, type is . """ - make_sure(is_tf_const_op(op), "{} isn't a const op".format(op.name)) + make_sure(is_tf_const_op(op), "%r isn't a const op", op.name) value = get_tf_tensor_data(op.get_attr("value")) if as_list: value = value.tolist() @@ -119,9 +118,6 @@ def map_tf_dtype(dtype): def get_tf_node_attr(node, name): """Parser TF node attribute.""" - if six.PY2: - # For python2, TF get_attr does not accept unicode - name = str(name) return node.get_attr(name) @@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override): """ # ignore the following attributes - ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings", + ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings", "TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples", "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval", "Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond", "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge", "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes", "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes", - "Toutput_types"] + "Toutput_types"} node_list = g.get_operations() functions = {} @@ -176,12 +172,11 @@ def tflist_to_onnx(g, shape_override): attr_cnt[a] += 1 if a == "dtype": attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype")) - elif a in ["T"]: + elif a == "T": dtype = get_tf_node_attr(node, a) - if dtype: - if not isinstance(dtype, list): - dtypes[node.name] = map_tf_dtype(dtype) - elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]: + if dtype and not isinstance(dtype, list): + dtypes[node.name] = map_tf_dtype(dtype) + elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}: # Tidx is used by Range # out_idx is used by ListDiff attr[a] = map_tf_dtype(get_tf_node_attr(node, a)) @@ -192,7 +187,7 @@ def tflist_to_onnx(g, shape_override): elif a == "output_shapes": # we should not need it since we pull the shapes above already pass - elif a in ["body", "cond", "then_branch", "else_branch"]: + elif a in {"body", "cond", "then_branch", "else_branch"}: input_shapes = [inp.get_shape() for inp in node.inputs] nattr = get_tf_node_attr(node, a) attr[a] = nattr.name