diff --git a/tools/onnx-graphsurgeon/CHANGELOG.md b/tools/onnx-graphsurgeon/CHANGELOG.md index b2ed0ba3..b5bdf89f 100644 --- a/tools/onnx-graphsurgeon/CHANGELOG.md +++ b/tools/onnx-graphsurgeon/CHANGELOG.md @@ -3,9 +3,17 @@ Dates are in YYYY-MM-DD format. +## v0.3.11 (2021-07-14) +### Changed +- Updated `fold_constants()` so that it no longer fails if a shape folding pass fails when `error_ok` is `True`. + +### Fixed +- Fixed a bug where `fold_constants()` would fail if a model contained a `Slice` node without a `starts` or `ends` input. + + ## v0.3.10 (2021-05-20) ### Added -- Added support for folding `Shape -> Slice` patterns even when the entire shape may not be known. +- Added support for folding `Shape -> Slice` patterns even when the entire shape may not be known. ## v0.3.9 (2021-04-20) diff --git a/tools/onnx-graphsurgeon/docs/conf.py b/tools/onnx-graphsurgeon/docs/conf.py index c77937f0..4650ae1c 100644 --- a/tools/onnx-graphsurgeon/docs/conf.py +++ b/tools/onnx-graphsurgeon/docs/conf.py @@ -15,16 +15,17 @@ # import sys import os + ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), os.path.pardir) sys.path.insert(0, ROOT_DIR) import onnx_graphsurgeon as gs extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'sphinx.ext.mathjax', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", ] # Want to be able to generate docs with no dependencies installed @@ -42,50 +43,48 @@ autosummary_generate = True -source_suffix = ['.rst'] +source_suffix = [".rst"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'ONNX GraphSurgeon' -copyright = '2020, NVIDIA' -author = 'NVIDIA' +project = "ONNX GraphSurgeon" +copyright = "2020, NVIDIA" +author = "NVIDIA" version = gs.__version__ # The full version, including alpha/beta/rc tags. release = version # Style -pygments_style = 'colorful' +pygments_style = "colorful" -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Use the TRT theme and NVIDIA logo -html_static_path = ['_static'] +html_static_path = ["_static"] -html_logo = '_static/img/nvlogo_white.png' +html_logo = "_static/img/nvlogo_white.png" # Hide source link html_show_sourcelink = False # Output file base name for HTML help builder. -htmlhelp_basename = 'OnnxGraphSurgeonDoc' +htmlhelp_basename = "OnnxGraphSurgeonDoc" # Template files to extend default Sphinx templates. # See https://www.sphinx-doc.org/en/master/templating.html for details. templates_path = ["_templates"] # For constructor arguments to show up in Sphinx generated doc -autoclass_content = 'both' +autoclass_content = "both" # Unlimited depth sidebar. -html_theme_options = { - 'navigation_depth': -1 -} +html_theme_options = {"navigation_depth": -1} -html_sidebars = { '**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html'] } +html_sidebars = {"**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"]} # Allows us to override the default page width in the Sphinx theme. 
def setup(app): - app.add_css_file('style.css') + app.add_css_file("style.css") diff --git a/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py b/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py index 411ee2f5..32335ce1 100644 --- a/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py +++ b/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py @@ -50,7 +50,9 @@ def mul(self, a, b): @gs.Graph.register() def gemm(self, a, b, trans_a=False, trans_b=False): attrs = {"transA": int(trans_a), "transB": int(trans_b)} - return propagate_dtype(self.layer(op="Gemm", inputs=[a, b], outputs=["gemm_out_gs"], attrs=attrs), a.dtype or b.dtype) + return propagate_dtype( + self.layer(op="Gemm", inputs=[a, b], outputs=["gemm_out_gs"], attrs=attrs), a.dtype or b.dtype + ) # You can also specify a set of opsets when regsitering a function. diff --git a/tools/onnx-graphsurgeon/examples/08_replacing_a_subgraph/generate.py b/tools/onnx-graphsurgeon/examples/08_replacing_a_subgraph/generate.py index a8c17059..47a1cbdc 100644 --- a/tools/onnx-graphsurgeon/examples/08_replacing_a_subgraph/generate.py +++ b/tools/onnx-graphsurgeon/examples/08_replacing_a_subgraph/generate.py @@ -23,10 +23,12 @@ def min(self, *args): return self.layer(op="Min", inputs=args, outputs=["min_out"])[0] + @gs.Graph.register() def max(self, *args): return self.layer(op="Max", inputs=args, outputs=["max_out"])[0] + @gs.Graph.register() def identity(self, inp): return self.layer(op="Identity", inputs=[inp], outputs=["identity_out"])[0] @@ -44,7 +46,9 @@ def identity(self, inp): # Add identity nodes to make the graph structure a bit more interesting inp = graph.identity(graph.inputs[0]) max_out = graph.max(graph.min(inp, MAX_VAL), MIN_VAL) -graph.outputs = [graph.identity(max_out), ] +graph.outputs = [ + graph.identity(max_out), +] # Graph outputs must include dtype information graph.outputs[0].to_variable(dtype=np.float32, shape=(4, 4)) diff --git a/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py b/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py index 752d3611..c5d92b0d 100644 --- a/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py +++ b/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py @@ -29,7 +29,9 @@ def shape(self, a): @gs.Graph.register() def reduce_prod(self, a, axes, keepdims=True): - return self.layer(op="ReduceProd", inputs=[a], attrs={"axes": axes, "keepdims": int(keepdims)}, outputs=["reduce_prod_out_gs"])[0] + return self.layer( + op="ReduceProd", inputs=[a], attrs={"axes": axes, "keepdims": int(keepdims)}, outputs=["reduce_prod_out_gs"] + )[0] @gs.Graph.register() @@ -69,8 +71,8 @@ def concat(self, inputs, axis=0): partially_flattened = graph.reshape(graph.inputs[0], new_shape) # Finally, set up the outputs and export. -flattened.name = "flattened" # Rename output tensor to make it easy to find. -flattened.dtype = np.float32 # NOTE: We must include dtype information for graph outputs +flattened.name = "flattened" # Rename output tensor to make it easy to find. 
+flattened.dtype = np.float32 # NOTE: We must include dtype information for graph outputs partially_flattened.name = "partially_flattened" partially_flattened.dtype = np.float32 diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py index be1fe448..59cd747e 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py @@ -5,4 +5,4 @@ from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable from onnx_graphsurgeon.util.exception import OnnxGraphSurgeonException -__version__ = "0.3.10" +__version__ = "0.3.11" diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/base_exporter.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/base_exporter.py index c434b68b..e9008344 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/base_exporter.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/base_exporter.py @@ -16,6 +16,7 @@ from onnx_graphsurgeon.ir.graph import Graph + class BaseExporter(object): @staticmethod def export_graph(graph: Graph): diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py index 43611fb1..6cc68398 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -42,11 +42,14 @@ def export_tensor_proto(tensor: Constant) -> onnx.TensorProto: onnx_tensor.name = tensor.name return onnx_tensor - @staticmethod def export_value_info_proto(tensor: Variable, do_type_check: bool) -> onnx.ValueInfoProto: if do_type_check and tensor.dtype is None: - G_LOGGER.critical("Graph input and output tensors must include dtype information. Please set the dtype attribute for: {:}".format(tensor)) + G_LOGGER.critical( + "Graph input and output tensors must include dtype information. 
Please set the dtype attribute for: {:}".format( + tensor + ) + ) if tensor.dtype is not None: onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape) @@ -54,11 +57,12 @@ def export_value_info_proto(tensor: Variable, do_type_check: bool) -> onnx.Value onnx_tensor = onnx.helper.make_empty_tensor_value_info(tensor.name) return onnx_tensor - @staticmethod def export_node(node: Node, do_type_check: bool) -> onnx.NodeProto: # Cannot pass in attrs directly as make_node will change the order - onnx_node = onnx.helper.make_node(node.op, inputs=[t.name for t in node.inputs], outputs=[t.name for t in node.outputs], name=node.name) + onnx_node = onnx.helper.make_node( + node.op, inputs=[t.name for t in node.inputs], outputs=[t.name for t in node.outputs], name=node.name + ) # Convert Tensors and Graphs to TensorProtos and GraphProtos respectively for key, val in node.attrs.items(): if isinstance(val, Tensor): @@ -68,7 +72,6 @@ def export_node(node: Node, do_type_check: bool) -> onnx.NodeProto: onnx_node.attribute.extend([onnx.helper.make_attribute(key, val)]) return onnx_node - @staticmethod def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: """ @@ -83,7 +86,9 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs] outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs] tensor_map = graph.tensors() - initializer = [OnnxExporter.export_tensor_proto(tensor) for tensor in tensor_map.values() if isinstance(tensor, Constant)] + initializer = [ + OnnxExporter.export_tensor_proto(tensor) for tensor in tensor_map.values() if isinstance(tensor, Constant) + ] # Remove inputs and outputs to export ValueInfoProtos for tensor in graph.inputs + graph.outputs: @@ -93,9 +98,22 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: # Omit tensors from value_info if we don't know their shape/dtype def has_value_info(tensor): return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None) - value_info = [OnnxExporter.export_value_info_proto(tensor, do_type_check) for tensor in tensor_map.values() if has_value_info(tensor)] - return onnx.helper.make_graph(nodes=nodes, name=graph.name, inputs=inputs, outputs=outputs, initializer=initializer, doc_string=graph.doc_string, value_info=value_info) + value_info = [ + OnnxExporter.export_value_info_proto(tensor, do_type_check) + for tensor in tensor_map.values() + if has_value_info(tensor) + ] + + return onnx.helper.make_graph( + nodes=nodes, + name=graph.name, + inputs=inputs, + outputs=outputs, + initializer=initializer, + doc_string=graph.doc_string, + value_info=value_info, + ) def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto": diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/base_importer.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/base_importer.py index 832a74b1..5304a2aa 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/base_importer.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/base_importer.py @@ -16,6 +16,7 @@ from onnx_graphsurgeon.ir.graph import Graph + class BaseImporter(object): @staticmethod def import_graph(graph) -> Graph: diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/onnx_importer.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/onnx_importer.py index 50244912..783e6e15 100644 --- 
a/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/onnx_importer.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/importers/onnx_importer.py @@ -43,6 +43,7 @@ "STRINGS": "strings", } + def get_onnx_tensor_shape(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: shape = None if isinstance(onnx_tensor, onnx.TensorProto): @@ -69,6 +70,7 @@ def get_onnx_tensor_dtype(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorPro return onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type] return None + class OnnxImporter(BaseImporter): @staticmethod def get_opset(model: onnx.ModelProto): @@ -82,26 +84,32 @@ def get_opset(model: onnx.ModelProto): G_LOGGER.warning("Model does not contain opset information! Using default opset.") return None - @staticmethod def get_import_domains(model: onnx.ModelProto): return model.opset_import - @staticmethod def import_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> Tensor: if isinstance(onnx_tensor, onnx.TensorProto): data_location = int(onnx_tensor.data_location) if onnx_tensor.HasField("data_location") else None return Constant(name=onnx_tensor.name, values=LazyValues(onnx_tensor), data_location=data_location) else: - return Variable(name=onnx_tensor.name, dtype=get_onnx_tensor_dtype(onnx_tensor), shape=get_onnx_tensor_shape(onnx_tensor)) - + return Variable( + name=onnx_tensor.name, + dtype=get_onnx_tensor_dtype(onnx_tensor), + shape=get_onnx_tensor_shape(onnx_tensor), + ) @staticmethod - def import_node(onnx_node: onnx.NodeProto, tensor_map: "OrderedDict[str, Tensor]", subgraph_tensor_map: "OrderedDict[str, Tensor]") -> Node: + def import_node( + onnx_node: onnx.NodeProto, + tensor_map: "OrderedDict[str, Tensor]", + subgraph_tensor_map: "OrderedDict[str, Tensor]", + ) -> Node: def attrs_to_dict(attrs): attr_dict = OrderedDict() for attr in attrs: + def process_attr(attr_str: str): processed = getattr(attr, ONNX_PYTHON_ATTR_MAPPING[attr_str]) if attr_str == "STRING": @@ -109,7 +117,9 @@ def process_attr(attr_str: str): elif attr_str == "TENSOR": processed = OnnxImporter.import_tensor(processed) elif attr_str == "GRAPH": - processed = OnnxImporter.import_graph(processed, misc.combine_dicts(tensor_map, subgraph_tensor_map)) + processed = OnnxImporter.import_graph( + processed, misc.combine_dicts(tensor_map, subgraph_tensor_map) + ) elif attr_str == "FLOATS" or attr_str == "INTS": processed = list(processed) elif attr_str == "STRINGS": @@ -121,9 +131,15 @@ def process_attr(attr_str: str): if attr_str in ONNX_PYTHON_ATTR_MAPPING: attr_dict[attr.name] = process_attr(attr_str) else: - G_LOGGER.warning("Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str)) + G_LOGGER.warning( + "Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str) + ) else: - G_LOGGER.warning("Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute.".format(attr.type)) + G_LOGGER.warning( + "Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute.".format( + attr.type + ) + ) return attr_dict # Optional inputs/outputs are represented by empty tensors. All other tensors should already have been populated during shape inference. 
@@ -140,31 +156,43 @@ def get_tensor(name: str, check_outer_graph=True): G_LOGGER.verbose("Generating empty tensor") return Variable.empty() - G_LOGGER.verbose("Tensor: {:} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor.".format(name)) + G_LOGGER.verbose( + "Tensor: {:} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor.".format( + name + ) + ) subgraph_tensor_map[name] = Variable(name) return subgraph_tensor_map[name] - # Retrieve Tensors for node inputs/outputs. Only empty tensors should need to be newly added. def retrieve_node_inputs() -> List[Tensor]: - inputs = [] # List[Tensor] + inputs = [] # List[Tensor] for input_name in onnx_node.input: inputs.append(get_tensor(input_name)) return inputs def retrieve_node_outputs() -> List[Tensor]: - outputs = [] # List[Tensor] + outputs = [] # List[Tensor] for output_name in onnx_node.output: # Node outputs cannot come from the outer graph, they must be created within the inner graph. outputs.append(get_tensor(output_name, check_outer_graph=False)) return outputs - return Node(op=onnx_node.op_type, name=onnx_node.name, attrs=attrs_to_dict(onnx_node.attribute), inputs=retrieve_node_inputs(), outputs=retrieve_node_outputs()) - - + return Node( + op=onnx_node.op_type, + name=onnx_node.name, + attrs=attrs_to_dict(onnx_node.attribute), + inputs=retrieve_node_inputs(), + outputs=retrieve_node_outputs(), + ) @staticmethod - def import_graph(onnx_graph: onnx.GraphProto, tensor_map: "OrderedDict[str, Tensor]"=None, opset=None, import_domains: onnx.OperatorSetIdProto=None) -> Graph: + def import_graph( + onnx_graph: onnx.GraphProto, + tensor_map: "OrderedDict[str, Tensor]" = None, + opset=None, + import_domains: onnx.OperatorSetIdProto = None, + ) -> Graph: """ Imports a Graph from an ONNX Graph. @@ -174,20 +202,26 @@ def import_graph(onnx_graph: onnx.GraphProto, tensor_map: "OrderedDict[str, Tens tensor_map (OrderedDict[str, Tensor]): A mapping of tensor names to Tensors. This is generally only useful for subgraph import. opset (int): The ONNX opset to use for this graph. """ - tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict())) # Outer graph tensors, read-only - subgraph_tensor_map = OrderedDict() # Tensors in this subgraph + tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict())) # Outer graph tensors, read-only + subgraph_tensor_map = OrderedDict() # Tensors in this subgraph # Retrieves a Tensor from subgraph_tensor_map or the outer graph (tensor_map) if present, otherwise imports the tensor # If overwrite=True, this function will overwrite previously imported tensors # if the new tensor has more information available. 
- def get_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], overwrite=False, check_outer_graph=True) -> Tensor: + def get_tensor( + onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], overwrite=False, check_outer_graph=True + ) -> Tensor: # Prioritize the subgraph even if check_outer_graph is set if onnx_tensor.name in subgraph_tensor_map: if overwrite: tensor = OnnxImporter.import_tensor(onnx_tensor) if isinstance(subgraph_tensor_map[onnx_tensor.name], Variable): - subgraph_tensor_map[onnx_tensor.name].dtype = subgraph_tensor_map[onnx_tensor.name].dtype or tensor.dtype - subgraph_tensor_map[onnx_tensor.name].shape = subgraph_tensor_map[onnx_tensor.name].shape or tensor.shape + subgraph_tensor_map[onnx_tensor.name].dtype = ( + subgraph_tensor_map[onnx_tensor.name].dtype or tensor.dtype + ) + subgraph_tensor_map[onnx_tensor.name].shape = ( + subgraph_tensor_map[onnx_tensor.name].shape or tensor.shape + ) return subgraph_tensor_map[onnx_tensor.name] if check_outer_graph and onnx_tensor.name in tensor_map: @@ -196,7 +230,6 @@ def get_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], overwr subgraph_tensor_map[onnx_tensor.name] = OnnxImporter.import_tensor(onnx_tensor) return subgraph_tensor_map[onnx_tensor.name] - # Import initializers contents into Constants. G_LOGGER.verbose("Importing initializers") for initializer in onnx_graph.initializer: @@ -213,25 +246,33 @@ def get_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], overwr # Graph inputs and outputs can never come from the outer graph! initializer_names = set([tensor.name for tensor in onnx_graph.initializer]) G_LOGGER.verbose("Importing graph inputs") - graph_inputs = [] # List[Tensor] + graph_inputs = [] # List[Tensor] for inp in onnx_graph.input: if inp.name not in initializer_names: tensor = get_tensor(inp, check_outer_graph=False) graph_inputs.append(tensor) G_LOGGER.verbose("Importing graph outputs") - graph_outputs = [] # List[Tensor] + graph_outputs = [] # List[Tensor] for out in onnx_graph.output: tensor = get_tensor(out, check_outer_graph=False) graph_outputs.append(tensor) G_LOGGER.verbose("Importing nodes") - nodes = [] # List[Node] + nodes = [] # List[Node] for onnx_node in onnx_graph.node: node = OnnxImporter.import_node(onnx_node, tensor_map, subgraph_tensor_map) nodes.append(node) - return Graph(nodes=nodes, inputs=graph_inputs, outputs=graph_outputs, name=onnx_graph.name, doc_string=onnx_graph.doc_string, opset=opset, import_domains=import_domains) + return Graph( + nodes=nodes, + inputs=graph_inputs, + outputs=graph_outputs, + name=onnx_graph.name, + doc_string=onnx_graph.doc_string, + opset=opset, + import_domains=import_domains, + ) def import_onnx(onnx_model: "onnx.ModelProto") -> Graph: @@ -244,4 +285,8 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph: Returns: Graph: A corresponding onnx-graphsurgeon Graph. 
""" - return OnnxImporter.import_graph(onnx_model.graph, opset=OnnxImporter.get_opset(onnx_model), import_domains=OnnxImporter.get_import_domains(onnx_model)) + return OnnxImporter.import_graph( + onnx_model.graph, + opset=OnnxImporter.get_opset(onnx_model), + import_domains=OnnxImporter.get_import_domains(onnx_model), + ) diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py index 6c108a6f..bf7604b4 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py @@ -44,10 +44,10 @@ class Graph(object): """ Represents a graph containing nodes and tensors. """ - DEFAULT_OPSET = 11 - OPSET_FUNC_MAP = defaultdict(dict) # Ops registered for specific opsets. - GLOBAL_FUNC_MAP = dict() # Ops registered for ALL opsets. + DEFAULT_OPSET = 11 + OPSET_FUNC_MAP = defaultdict(dict) # Ops registered for specific opsets. + GLOBAL_FUNC_MAP = dict() # Ops registered for ALL opsets. @staticmethod def register(opsets=None): @@ -72,10 +72,13 @@ def add(self, a, b): function previously registered for those opsets. By default, the function is registered for all opsets. """ + def register_func(func): if hasattr(Graph, func.__name__): - G_LOGGER.warning("Registered function: {:} is hidden by a Graph attribute or function with the same name. " - "This function will never be called!".format(func.__name__)) + G_LOGGER.warning( + "Registered function: {:} is hidden by a Graph attribute or function with the same name. " + "This function will never be called!".format(func.__name__) + ) # Default behavior is to register functions for all opsets. if opsets is None: @@ -84,10 +87,19 @@ def register_func(func): for opset in opsets: Graph.OPSET_FUNC_MAP[opset][func.__name__] = func return func - return register_func + return register_func - def __init__(self, nodes: Sequence[Node]=None, inputs: Sequence[Tensor]=None, outputs: Sequence[Tensor]=None, name=None, doc_string=None, opset=None, import_domains=None): + def __init__( + self, + nodes: Sequence[Node] = None, + inputs: Sequence[Tensor] = None, + outputs: Sequence[Tensor] = None, + name=None, + doc_string=None, + opset=None, + import_domains=None, + ): """ Args: nodes (Sequence[Node]): A list of the nodes in this graph. 
@@ -111,7 +123,6 @@ def __init__(self, nodes: Sequence[Node]=None, inputs: Sequence[Tensor]=None, ou # For layer() function self.name_idx = 0 - def __getattr__(self, name): try: return super().__getattribute__(name) @@ -126,21 +137,24 @@ def __getattr__(self, name): G_LOGGER.error("No function: {:} registered for opset: {:}".format(name, self.opset)) raise err - def __setattr__(self, name, value): # We don't want graph inputs/outputs to be SynchronizedLists if name in ["inputs", "outputs"]: value = list(value) return super().__setattr__(name, value) - def __eq__(self, other: "Graph"): - nodes_match = len(self.nodes) == len(other.nodes) and all([node == other_node for node, other_node in zip(self.nodes, other.nodes)]) - inputs_match = len(self.inputs) == len(other.inputs) and all([inp == other_inp for inp, other_inp in zip(self.inputs, other.inputs)]) - outputs_match = len(self.outputs) == len(other.outputs) and all([out == other_out for out, other_out in zip(self.outputs, other.outputs)]) + nodes_match = len(self.nodes) == len(other.nodes) and all( + [node == other_node for node, other_node in zip(self.nodes, other.nodes)] + ) + inputs_match = len(self.inputs) == len(other.inputs) and all( + [inp == other_inp for inp, other_inp in zip(self.inputs, other.inputs)] + ) + outputs_match = len(self.outputs) == len(other.outputs) and all( + [out == other_out for out, other_out in zip(self.outputs, other.outputs)] + ) return nodes_match and inputs_match and outputs_match - def node_ids(self): """ Returns a context manager that supplies unique integer IDs for Nodes in the Graph. @@ -156,14 +170,14 @@ def node_ids(self): """ return NodeIDAdder(self) - def _get_node_id(self, node): try: return node.id except AttributeError: - G_LOGGER.critical("Encountered a node not in the graph:\n{:}.\n\n" - "To fix this, please append the node to this graph's `nodes` attribute.".format(node)) - + G_LOGGER.critical( + "Encountered a node not in the graph:\n{:}.\n\n" + "To fix this, please append the node to this graph's `nodes` attribute.".format(node) + ) # A tensor is local if it is produced in this graph, or is explicitly a graph input. def _local_tensors(self): @@ -172,7 +186,6 @@ def _local_tensors(self): local_tensors.update({t.name: t for t in self.tensors().values() if isinstance(t, Constant)}) return local_tensors - # Returns tensors used by this graph which are not present in the graph. # These may come from an outer graph for example. def _foreign_tensors(self): @@ -190,15 +203,12 @@ def is_foreign_tensor(tensor): subgraph_foreign_tensors = attr._foreign_tensors() # Some of the foreign tensors from a subgraph may come from this graph. subgraph_foreign_tensors = { - t.name: t - for t in subgraph_foreign_tensors.values() - if is_foreign_tensor(t) + t.name: t for t in subgraph_foreign_tensors.values() if is_foreign_tensor(t) } foreign_tensors.update(subgraph_foreign_tensors) return foreign_tensors - def _get_used_node_ids(self): local_tensors = self._local_tensors() @@ -209,7 +219,6 @@ def __init__(self, initial_tensors=None): tensors = misc.default_value(initial_tensors, []) self.seen_tensors = set([tensor.name for tensor in tensors]) - def __call__(self, tensor): # Returns True if a tensor should included, # False if it should be filtered out. @@ -222,7 +231,6 @@ def __call__(self, tensor): return True return False - # Traverse backwards from outputs to find all used nodes. 
ignore_tensors = IgnoreDupAndForeign() used_tensors = list(filter(ignore_tensors, self.outputs)) @@ -245,7 +253,6 @@ def __call__(self, tensor): used_tensors.extend(filter(ignore_tensors, node_used_tensors)) return used_node_ids, used_tensors - def cleanup(self, remove_unused_node_outputs=False, recurse_subgraphs=True, remove_unused_graph_inputs=False): """ Removes unused nodes and tensors from the graph. @@ -269,13 +276,15 @@ def cleanup(self, remove_unused_node_outputs=False, recurse_subgraphs=True, remo Returns: self """ + def cleanup_subgraphs(): for node in self.nodes: for attr in node.attrs.values(): if isinstance(attr, Graph): - attr.cleanup(remove_unused_node_outputs=remove_unused_node_outputs, - remove_unused_graph_inputs=remove_unused_graph_inputs) - + attr.cleanup( + remove_unused_node_outputs=remove_unused_node_outputs, + remove_unused_graph_inputs=remove_unused_graph_inputs, + ) if recurse_subgraphs: cleanup_subgraphs() @@ -310,8 +319,11 @@ def cleanup_subgraphs(): if remove_unused_node_outputs: graph_output_names = set([tensor.name for tensor in self.outputs]) for node in nodes: + def is_hanging_tensor(tensor): - return not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names + return ( + not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names + ) to_remove = [out for out in node.outputs if is_hanging_tensor(out)] for out in to_remove: @@ -321,7 +333,6 @@ def is_hanging_tensor(tensor): self.nodes = nodes return self - def toposort(self, recurse_subgraphs=True): """ Topologically sort the graph in place. @@ -352,7 +363,7 @@ def __init__(self, node=None, level=None): def __lt__(self, other): return self.level < other.level - hierarchy_levels = {} # Dict[int, HierarchyDescriptor] + hierarchy_levels = {} # Dict[int, HierarchyDescriptor] local_tensors = self._local_tensors() @@ -384,7 +395,6 @@ def get_input_nodes(node): self.nodes = [hd.node for hd in sorted(hierarchy_levels.values())] return self - def tensors(self, check_duplicates=False): """ Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted from this map. @@ -405,12 +415,14 @@ def tensors(self, check_duplicates=False): def add_to_tensor_map(tensor): if not tensor.is_empty(): if check_duplicates and tensor.name in tensor_map and not (tensor_map[tensor.name] is tensor): - G_LOGGER.critical("Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}" - .format(id(tensor_map[tensor.name]), tensor_map[tensor.name], id(tensor), tensor)) + G_LOGGER.critical( + "Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}".format( + id(tensor_map[tensor.name]), tensor_map[tensor.name], id(tensor), tensor + ) + ) tensor_map[tensor.name] = tensor - # I/O tensors may not be attached to nodes. for io_tensor in self.inputs: add_to_tensor_map(io_tensor) @@ -424,7 +436,6 @@ def add_to_tensor_map(tensor): return tensor_map - def fold_constants(self, fold_shapes=True, recurse_subgraphs=True, partitioning=None, error_ok=True): """ Folds constants in-place in the graph. The graph must be topologically sorted prior to @@ -496,7 +507,6 @@ def all_tensors_const(tensors): all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors) return all_subgraph_foreign_tensors_const - # Walks along the outputs of graph_constants to see if they can also be computed statically. # Since the graph is topologically sorted, this should find all constant nodes in the graph. 
for node in graph_clone.nodes: @@ -513,12 +523,13 @@ def all_tensors_const(tensors): if len(tensor.inputs) == 1: node = tensor.inputs[0] if node.op == "Constant": - graph_constants[tensor.name] = tensor.to_constant(node.attrs["value"]._values) # Using ._values avoids copying + graph_constants[tensor.name] = tensor.to_constant( + node.attrs["value"]._values + ) # Using ._values avoids copying graph_constants[tensor.name].inputs.clear() graph_constants = update_foldable_outputs(graph_constants) - # Pass 2: Shape Folding def get_producer(tensor, op): @@ -533,7 +544,6 @@ def get_producer(tensor, op): return None return node - def get_input(node, index=0): """ Get the input tensor of a node iff the input tensor is not already marked a graph constant. @@ -549,8 +559,16 @@ def get_input(node, index=0): return inp + def get_scalar_value(tensor): + """ + Gets the scalar value of a tensor with a single item + """ + if not tensor.shape: + return tensor.values + else: + return list(tensor.values)[0] - def handle_shape(tensor): + def fold_shape(tensor): inp = get_input(get_producer(tensor, "Shape")) if inp is None: return None @@ -559,8 +577,7 @@ def handle_shape(tensor): return None return np.array(inp.shape, dtype=np.int64) - - def handle_shape_gather(tensor): + def fold_shape_gather(tensor): gather = get_producer(tensor, "Gather") if gather is None: return None @@ -576,7 +593,7 @@ def handle_shape_gather(tensor): return None indices = indices_tensor.values - if not indices.shape: # Scalar-case + if not indices.shape: # Scalar-case shape = inp.shape[int(indices)] if misc.is_dynamic_dimension(shape): return None @@ -587,34 +604,37 @@ def handle_shape_gather(tensor): return np.array(shape, dtype=np.int64) - - def handle_shape_slice(tensor): + def fold_shape_slice(tensor): slice = get_producer(tensor, "Slice") if slice is None: return None data = slice.inputs[0] - starts, ends = slice.inputs[1:3] - inp = get_input(get_producer(data, "Shape")) - if inp is None or inp.shape is None: + if len(slice.inputs) >= 3: + starts, ends = slice.inputs[1:3] + if any(not isinstance(t, Constant) for t in [starts, ends]): + return None + starts, ends = get_scalar_value(starts), get_scalar_value(ends) + elif "starts" in slice.attrs and "ends" in slice.attrs: + starts, ends = slice.attrs["starts"][0], slice.attrs["ends"][0] + else: return None - if any(not isinstance(t, Constant) for t in [starts, ends]): + inp = get_input(get_producer(data, "Shape")) + if inp is None or inp.shape is None: return None - def get_value(tensor): # Gets the integer value of a tensor with a single item - if not tensor.shape: - return tensor.values - else: - return list(tensor.values)[0] - + # For shape tensors, we can only slice on the 0th dimension. 
if len(slice.inputs) > 3: axes = slice.inputs[3] if not isinstance(axes, Constant): return None - if get_value(axes) != 0: + if get_scalar_value(axes) != 0: + return None + elif "axes" in slice.attrs: + if slice.attrs["axes"][0] != 0: return None steps = 1 @@ -622,36 +642,34 @@ def get_value(tensor): # Gets the integer value of a tensor with a single item steps = slice.inputs[4] if not isinstance(steps, Constant): return None + steps = get_scalar_value(steps) + elif "steps" in slice.attrs: + steps = slice.attrs["steps"][0] - steps = get_value(steps) - - shape = inp.shape[get_value(starts):get_value(ends):steps] + shape = inp.shape[starts:ends:steps] if misc.is_dynamic_shape(shape): return None - - return np.array(shape, dtype=np.int64) - - - # Finds the static shape of a shape node output if possible, otherwise returns None. - def lower_shape(tensor): - SHAPE_FOLD_FUNCS = [handle_shape_gather, handle_shape_slice, handle_shape] - for fold_func in SHAPE_FOLD_FUNCS: - shape = fold_func(tensor) - if shape is not None: - return shape + return np.array(shape, dtype=np.int64) if fold_shapes: - for tensor in clone_tensors.values(): - shape_of = lower_shape(tensor) - - if shape_of is not None: - G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of)) - graph_constants[tensor.name] = tensor.to_constant(shape_of) - graph_constants[tensor.name].inputs.clear() - - graph_constants = update_foldable_outputs(graph_constants) + # NOTE: The order of shape folding passes is important to maximize how much we fold (phase-ordering problem). + SHAPE_FOLD_FUNCS = [fold_shape_gather, fold_shape_slice, fold_shape] + for shape_fold_func in SHAPE_FOLD_FUNCS: + try: + for tensor in clone_tensors.values(): + shape_of = shape_fold_func(tensor) + if shape_of is not None: + G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of)) + graph_constants[tensor.name] = tensor.to_constant(shape_of) + graph_constants[tensor.name].inputs.clear() + except Exception as err: + if not error_ok: + raise err + G_LOGGER.warning("'{:}' routine failed with:\n{:}".format(shape_fold_func.__name__, err)) + else: + graph_constants = update_foldable_outputs(graph_constants) def partition_and_infer(subgraph): def get_out_node_ids(): @@ -668,7 +686,7 @@ def get_out_node_ids(): out_node_ids = get_out_node_ids() constant_values = {} - for index in out_node_ids: # Have to use index since 'node' is not in part + for index in out_node_ids: # Have to use index since 'node' is not in part part = subgraph.copy() out_node = part.nodes[index] part.outputs = out_node.outputs @@ -701,7 +719,6 @@ def get_out_node_ids(): return constant_values - # Next, evaluate the foldable variables with ONNX-Runtime # Only evaluate foldable values that have non-foldable outputs or are graph outputs. @@ -717,7 +734,9 @@ def should_eval_foldable(tensor): graph_clone.cleanup(remove_unused_graph_inputs=True) # Using ._values avoids a deep copy of the values. - constant_values = {name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant)} + constant_values = { + name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant) + } if graph_clone.outputs: if partitioning: constant_values.update(partition_and_infer(graph_clone)) @@ -728,15 +747,19 @@ def should_eval_foldable(tensor): values = sess.run(names, {}) constant_values.update({name: val for name, val in zip(names, values)}) except Exception as err: - G_LOGGER.warning("Inference failed. 
You may want to try enabling partitioning to see better results. " - "Note: Error was:\n{:}".format(err)) + G_LOGGER.warning( + "Inference failed. You may want to try enabling partitioning to see better results. " + "Note: Error was:\n{:}".format(err) + ) G_LOGGER.verbose("Note: Graph was:\n{:}".format(graph_clone)) if not error_ok: raise elif not constant_values: - G_LOGGER.info("Could not find any nodes in this graph ({:}) that can be folded. " - "This could mean that constant folding has already been run on this graph. " - "Skipping.".format(self.name)) + G_LOGGER.info( + "Could not find any nodes in this graph ({:}) that can be folded. " + "This could mean that constant folding has already been run on this graph. " + "Skipping.".format(self.name) + ) # Finally, replace the Variables in the original graph with constants. if constant_values: @@ -745,8 +768,7 @@ def should_eval_foldable(tensor): tensor = graph_tensors[name] if not isinstance(tensor, Constant): tensor.to_constant(values) - tensor.inputs.clear() # Constants do not need inputs - + tensor.inputs.clear() # Constants do not need inputs # Folding subgraphs after the outer graph can lead to better folding. def fold_subgraphs(): @@ -760,13 +782,11 @@ def fold_subgraphs(): return self - def _generate_name(self, prefix): name = "{}_{}".format(prefix, self.name_idx) self.name_idx += 1 return name - def layer(self, inputs=[], outputs=[], *args, **kwargs): """ Creates a node, adds it to this graph, and optionally creates its input and output tensors. @@ -794,6 +814,7 @@ def layer(self, inputs=[], outputs=[], *args, **kwargs): Returns: List[Tensor]: The output tensors of the node """ + def process_io(io): new_io = [] for elem in io: @@ -809,9 +830,11 @@ def process_io(io): arr = np.array(elem, dtype=dtype) new_io.append(Constant(name=self._generate_name("onnx_graphsurgeon_lst_constant"), values=arr)) else: - G_LOGGER.critical("Unrecognized type passed to Graph.layer: {:}.\n" - "\tHint: Did you forget to unpack a list with `*`?\n" - "\tPlease use Tensors, strings, or NumPy arrays.".format(elem)) + G_LOGGER.critical( + "Unrecognized type passed to Graph.layer: {:}.\n" + "\tHint: Did you forget to unpack a list with `*`?\n" + "\tPlease use Tensors, strings, or NumPy arrays.".format(elem) + ) return new_io inputs = process_io(inputs) @@ -824,8 +847,7 @@ def process_io(io): self.nodes.append(node) return node.outputs - - def copy(self, tensor_map: "OrderedDict[str, Tensor]"=None): + def copy(self, tensor_map: "OrderedDict[str, Tensor]" = None): """ Copy the graph. @@ -853,33 +875,37 @@ def copy(self, tensor_map: "OrderedDict[str, Tensor]"=None): # And locally produced tensors should take precedence over everything else. 
local_tensor_copies.update({n: t.copy() for n, t in self._local_tensors().items()}) - def get_tensor(name): if not name: return Variable.empty() return local_tensor_copies[name] - # Next, copy nodes, and update inputs/outputs new_nodes = [] for node in self.nodes: - new_node = node.copy(inputs=[get_tensor(inp.name) for inp in node.inputs], - outputs=[get_tensor(out.name) for out in node.outputs], - tensor_map=local_tensor_copies) + new_node = node.copy( + inputs=[get_tensor(inp.name) for inp in node.inputs], + outputs=[get_tensor(out.name) for out in node.outputs], + tensor_map=local_tensor_copies, + ) new_nodes.append(new_node) new_graph_inputs = [get_tensor(inp.name) for inp in self.inputs] new_graph_outputs = [get_tensor(out.name) for out in self.outputs] - return Graph(nodes=new_nodes, inputs=new_graph_inputs, outputs=new_graph_outputs, - name=copy.copy(self.name), doc_string=copy.copy(self.doc_string), - opset=copy.copy(self.opset)) - + return Graph( + nodes=new_nodes, + inputs=new_graph_inputs, + outputs=new_graph_outputs, + name=copy.copy(self.name), + doc_string=copy.copy(self.doc_string), + opset=copy.copy(self.opset), + ) def __str__(self): nodes_str = "\n".join([str(node) for node in self.nodes]) return "Graph {:} (Opset: {:})\nInputs: {:}\nNodes:\n{:}\nOutputs: {:}".format( - self.name, self.opset, self.inputs, nodes_str, self.outputs) - + self.name, self.opset, self.inputs, nodes_str, self.outputs + ) def __repr__(self): return self.__str__() diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/node.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/node.py index 0463f980..9712395a 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/node.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/node.py @@ -21,8 +21,16 @@ from collections import OrderedDict from typing import List, Dict + class Node(object): - def __init__(self, op: str, name: str=None, attrs: Dict[str, object]=None, inputs: List["Tensor"]=None, outputs: List["Tensor"]=None): + def __init__( + self, + op: str, + name: str = None, + attrs: Dict[str, object] = None, + inputs: List["Tensor"] = None, + outputs: List["Tensor"] = None, + ): """ A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors. @@ -40,7 +48,6 @@ def __init__(self, op: str, name: str=None, attrs: Dict[str, object]=None, input self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, [])) self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, [])) - def i(self, tensor_idx=0, producer_idx=0): """ Convenience function to get a producer node of one of this node's input tensors. @@ -61,7 +68,6 @@ def i(self, tensor_idx=0, producer_idx=0): """ return self.inputs[tensor_idx].inputs[producer_idx] - def o(self, consumer_idx=0, tensor_idx=0): """ Convenience function to get a consumer node of one of this node's output tensors. @@ -81,7 +87,6 @@ def o(self, consumer_idx=0, tensor_idx=0): """ return self.outputs[tensor_idx].outputs[consumer_idx] - def __setattr__(self, name, value): if name in ["inputs", "outputs"]: try: @@ -92,8 +97,7 @@ def __setattr__(self, name, value): else: super().__setattr__(name, value) - - def copy(self, inputs: List["Tensor"]=None, outputs: List["Tensor"]=None, tensor_map=None): + def copy(self, inputs: List["Tensor"] = None, outputs: List["Tensor"] = None, tensor_map=None): """ Makes a shallow copy of this node, overriding input and output information. 
@@ -110,10 +114,9 @@ def copy(self, inputs: List["Tensor"]=None, outputs: List["Tensor"]=None, tensor return Node(self.op, self.name, new_attrs, inputs=inputs, outputs=outputs) - def __str__(self): ret = "{:} ({:})".format(self.name, self.op) - + def add_io(name, io): nonlocal ret ret += "\n\t{:}: [".format(name) @@ -121,7 +124,6 @@ def add_io(name, io): ret += "\n\t\t{:}".format(elem) ret += "\n\t]" - add_io("Inputs", self.inputs) add_io("Outputs", self.outputs) @@ -129,17 +131,19 @@ def add_io(name, io): ret += "\nAttributes: {:}".format(self.attrs) return ret - def __repr__(self): return self.__str__() - def __eq__(self, other): """ Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs. """ G_LOGGER.verbose("Comparing node: {:} with {:}".format(self.name, other.name)) attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs - inputs_match = len(self.inputs) == len(other.inputs) and all([inp == other_inp for inp, other_inp in zip(self.inputs, other.inputs)]) - outputs_match = len(self.outputs) == len(other.outputs) and all([out == other_out for out, other_out in zip(self.outputs, other.outputs)]) + inputs_match = len(self.inputs) == len(other.inputs) and all( + [inp == other_inp for inp, other_inp in zip(self.inputs, other.inputs)] + ) + outputs_match = len(self.outputs) == len(other.outputs) and all( + [out == other_out for out, other_out in zip(self.outputs, other.outputs)] + ) return attrs_match and inputs_match and outputs_match diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py index 25d52df8..4696e3a3 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py @@ -23,6 +23,7 @@ class Tensor(object): """Abstract base class for tensors in a graph""" + DYNAMIC = -1 def __init__(self): @@ -31,7 +32,6 @@ def __init__(self): """ raise NotImplementedError("Tensor is an abstract class") - def __setattr__(self, name, value): if name in ["inputs", "outputs"]: try: @@ -42,7 +42,6 @@ def __setattr__(self, name, value): else: super().__setattr__(name, value) - def is_empty(self): """ Returns whether this tensor is considered empty in the graph. @@ -55,8 +54,7 @@ def is_empty(self): """ return self.name == "" - - def to_constant(self, values: np.ndarray, data_location: int=None): + def to_constant(self, values: np.ndarray, data_location: int = None): """ Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor will see the update. @@ -75,8 +73,7 @@ def to_constant(self, values: np.ndarray, data_location: int=None): self.data_location = data_location return self - - def to_variable(self, dtype: np.dtype=None, shape: Sequence[Union[int, str]]=[]): + def to_variable(self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = []): """ Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor will see the update. @@ -92,7 +89,6 @@ def to_variable(self, dtype: np.dtype=None, shape: Sequence[Union[int, str]]=[]) self.shape = shape return self - def i(self, tensor_idx=0, producer_idx=0): """ Convenience function to get an input tensor of one of this tensor's input nodes. 
@@ -113,7 +109,6 @@ def i(self, tensor_idx=0, producer_idx=0): """ return self.inputs[producer_idx].inputs[tensor_idx] - def o(self, consumer_idx=0, tensor_idx=0): """ Convenience function to get an output tensor of one of this tensor's output nodes. @@ -133,15 +128,12 @@ def o(self, consumer_idx=0, tensor_idx=0): """ return self.outputs[consumer_idx].outputs[tensor_idx] - def __str__(self): return "{:} ({:}): (shape={:}, dtype={:})".format(type(self).__name__, self.name, self.shape, self.dtype) - - def __repr__(self): # Hack to make logging output pretty. + def __repr__(self): # Hack to make logging output pretty. return self.__str__() - def __eq__(self, other): """ Perform a check to see if two tensors are equal. @@ -156,8 +148,7 @@ class Variable(Tensor): def empty(): return Variable(name="") - - def __init__(self, name: str, dtype: np.dtype=None, shape: Sequence[Union[int, str]]=None): + def __init__(self, name: str, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = None): """ Represents a Tensor whose value is not known until inference-time. @@ -172,13 +163,11 @@ def __init__(self, name: str, dtype: np.dtype=None, shape: Sequence[Union[int, s self.dtype = dtype self.shape = misc.default_value(shape, None) - def to_constant(self, values: np.ndarray): del self.dtype del self.shape return super().to_constant(values) - def copy(self): """ Makes a shallow copy of this tensor, omitting input and output information. @@ -192,17 +181,18 @@ class LazyValues(object): """ A special object that represents constant tensor values that should be lazily loaded. """ + def __init__(self, tensor): """ Args: tensor (onnx.TensorProto): The ONNX tensor that this instance should lazily load. """ from onnx_graphsurgeon.importers.onnx_importer import get_onnx_tensor_shape, get_onnx_tensor_dtype + self.tensor = tensor self.shape = get_onnx_tensor_shape(self.tensor) self.dtype = get_onnx_tensor_dtype(self.tensor) - def load(self): """ Load a numpy array from the underlying tensor values. @@ -212,19 +202,18 @@ def load(self): """ import onnx import onnx.numpy_helper - return np.array(onnx.numpy_helper.to_array(self.tensor)) + return np.array(onnx.numpy_helper.to_array(self.tensor)) def __str__(self): return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype) - - def __repr__(self): # Hack to make logging output pretty. + def __repr__(self): # Hack to make logging output pretty. return self.__str__() class Constant(Tensor): - def __init__(self, name: str, values: Union[np.ndarray, LazyValues], data_location: int=None): + def __init__(self, name: str, values: Union[np.ndarray, LazyValues], data_location: int = None): """ Represents a Tensor whose value is known. @@ -240,18 +229,18 @@ def __init__(self, name: str, values: Union[np.ndarray, LazyValues], data_locati self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=[]) self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=[]) if not isinstance(values, np.ndarray) and not isinstance(values, LazyValues): - G_LOGGER.critical("Provided `values` argument is not a NumPy array or a LazyValues instance. " - "Please provide a NumPy array or LazyValues instance to construct a Constant. " - "Note: Provided `values` parameter was: {:}".format(values)) + G_LOGGER.critical( + "Provided `values` argument is not a NumPy array or a LazyValues instance. " + "Please provide a NumPy array or LazyValues instance to construct a Constant. 
" + "Note: Provided `values` parameter was: {:}".format(values) + ) self._values = values self.data_location = data_location - - def to_variable(self, dtype: np.dtype=None, shape: Sequence[Union[int, str]]=[]): + def to_variable(self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = []): del self._values return super().to_variable(dtype, shape) - def copy(self): """ Makes a shallow copy of this tensor, omitting input and output information. @@ -260,7 +249,6 @@ def copy(self): """ return Constant(self.name, self._values) - @property def values(self): # Load values when they are first accesed @@ -268,23 +256,19 @@ def values(self): self._values = self._values.load() return self._values - @values.setter def values(self, values: Union[np.ndarray, LazyValues]): self._values = values - @property def shape(self): return self._values.shape - @property def dtype(self): return self._values.dtype.type - - def __repr__(self): # Hack to make logging output pretty. + def __repr__(self): # Hack to make logging output pretty. ret = self.__str__() ret += "\n{:}".format(self._values) return ret diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/logger/logger.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/logger/logger.py index 444d4ba1..2c5b5f6a 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/logger/logger.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/logger/logger.py @@ -53,8 +53,8 @@ def __exit__(self, exc_type, exc_value, traceback): class LogMode(enum.IntEnum): - EACH = 0 # Log the message each time - ONCE = 1 # Log the message only once. The same message will not be logged again. + EACH = 0 # Log the message each time + ONCE = 1 # Log the message only once. The same message will not be logged again. class Logger(object): @@ -99,7 +99,7 @@ def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, lin """ self._severity = severity self.logging_indent = 0 - self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) + self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) self.once_logged = set() self.colors = colors self.letter = letter @@ -107,19 +107,16 @@ def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, lin self.line_info = line_info self.logger_callbacks = [] - @property def severity(self): return self._severity - @severity.setter def severity(self, value): self._severity = value for callback in self.logger_callbacks: callback(self._severity) - def register_callback(self, callback): """ Registers a callback with the logger, which will be invoked when the logging severity is modified. @@ -131,14 +128,12 @@ def register_callback(self, callback): callback(self._severity) self.logger_callbacks.append(callback) - def indent(self, level=1): """ Returns a context manager that indents all strings logged by the specified amount. """ return LoggerIndent(self, level + self.logging_indent) - def suppress(self, severity=CRITICAL): """ Returns a context manager that temporarily changes the severity of the logger for its duration. @@ -148,8 +143,6 @@ def suppress(self, severity=CRITICAL): """ return LoggerSuppress(self, severity) - - # If once is True, the logger will only log this message a single time. Useful in loops. # message may be a callable which returns a message. This way, only if the message needs to be logged is it ever generated. 
def log(self, message, severity, mode=LogMode.EACH, stack_depth=2): @@ -184,20 +177,21 @@ def apply_color(message): if self.colors: try: import colored + color = Logger.SEVERITY_COLOR_MAPPING[severity] return colored.stylize(message, [colored.fg(color)]) except (ImportError, ModuleNotFoundError): self.colors = False - self.warning("colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored") + self.warning( + "colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored" + ) self.colors = True return message - prefix = get_prefix() message = apply_indentation(message) return apply_color("{:}{:}".format(prefix, message)) - def should_log(message): should = severity >= self._severity if mode == LogMode.ONCE: @@ -206,7 +200,6 @@ def should_log(message): self.once_logged.add(message_hash) return should - if not should_log(message): return @@ -215,35 +208,28 @@ def should_log(message): message = str(message) print(process_message(message, stack_depth=stack_depth)) - def ultra_verbose(self, message, mode=LogMode.EACH): self.log(message, Logger.ULTRA_VERBOSE, mode=mode, stack_depth=3) - def verbose(self, message, mode=LogMode.EACH): self.log(message, Logger.VERBOSE, mode=mode, stack_depth=3) - def debug(self, message, mode=LogMode.EACH): self.log(message, Logger.DEBUG, mode=mode, stack_depth=3) - def info(self, message, mode=LogMode.EACH): self.log(message, Logger.INFO, mode=mode, stack_depth=3) - def warning(self, message, mode=LogMode.EACH): self.log(message, Logger.WARNING, mode=mode, stack_depth=3) - def error(self, message, mode=LogMode.EACH): self.log(message, Logger.ERROR, mode=mode, stack_depth=3) - # Like error, but immediately exits. def critical(self, message): self.log(message, Logger.CRITICAL, stack_depth=3) - raise OnnxGraphSurgeonException(message) from None # Erase exception chain + raise OnnxGraphSurgeonException(message) from None # Erase exception chain global G_LOGGER diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/exception.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/exception.py index 2d4eed17..36853e68 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/exception.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/exception.py @@ -14,8 +14,10 @@ # limitations under the License. # + class OnnxGraphSurgeonException(Exception): """ An exception raised by ONNX-GraphSurgeon. 
""" + pass diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/misc.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/misc.py index d724bad8..b660b3f7 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/misc.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/util/misc.py @@ -80,33 +80,27 @@ def __init__(self, parent_obj, field_name, initial): self.field_name = field_name self.extend(initial) - def _add_to_elem(self, elem): # Explicitly avoid SynchronizedList overrides to prevent infinite recursion list.append(getattr(elem, self.field_name), self.parent_obj) - def _remove_from_elem(self, elem): # Explicitly avoid SynchronizedList overrides to prevent infinite recursion list.remove(getattr(elem, self.field_name), self.parent_obj) - def __delitem__(self, index): self._remove_from_elem(self[index]) super().__delitem__(index) - def __setitem__(self, index, elem): self._remove_from_elem(self[index]) super().__setitem__(index, elem) self._add_to_elem(elem) - def append(self, x): super().append(x) self._add_to_elem(x) - def extend(self, iterable: Sequence[object]): super().extend(iterable) for elem in iterable: @@ -116,28 +110,23 @@ def insert(self, i, x): super().insert(i, x) self._add_to_elem(x) - def remove(self, x): super().remove(x) self._remove_from_elem(x) - def pop(self, i=-1): elem = super().pop(i) self._remove_from_elem(elem) return elem - def clear(self): for elem in self: self._remove_from_elem(elem) super().clear() - def __add__(self, other_list: List[object]): return list(self) + list(other_list) - def __iadd__(self, other_list: List[object]): self.extend(other_list) return self diff --git a/tools/onnx-graphsurgeon/setup.py b/tools/onnx-graphsurgeon/setup.py index 99b5e245..682dfbb5 100644 --- a/tools/onnx-graphsurgeon/setup.py +++ b/tools/onnx-graphsurgeon/setup.py @@ -18,11 +18,12 @@ import onnx_graphsurgeon from setuptools import setup, find_packages + def no_publish(): - blacklist = ['register'] + blacklist = ["register"] for cmd in blacklist: if cmd in sys.argv: - raise RuntimeError("Command \"{}\" blacklisted".format(cmd)) + raise RuntimeError('Command "{}" blacklisted'.format(cmd)) REQUIRED_PACKAGES = [ @@ -30,6 +31,7 @@ def no_publish(): "onnx", ] + def main(): no_publish() setup( @@ -42,8 +44,8 @@ def main(): author="NVIDIA", author_email="svc_tensorrt@nvidia.com", classifiers=[ - 'Intended Audience :: Developers', - 'Programming Language :: Python :: 3', + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", @@ -55,5 +57,6 @@ def main(): zip_safe=True, ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tools/onnx-graphsurgeon/tests/ir/test_graph.py b/tools/onnx-graphsurgeon/tests/ir/test_graph.py index ec467bdc..e69681c7 100644 --- a/tools/onnx-graphsurgeon/tests/ir/test_graph.py +++ b/tools/onnx-graphsurgeon/tests/ir/test_graph.py @@ -37,7 +37,9 @@ def shape(self, inp): @Graph.register() def constant(self, values): - return self.layer(op="Constant", inputs=[], outputs=["constant_out"], attrs={"value": Constant("values", values)})[0] + return self.layer(op="Constant", inputs=[], outputs=["constant_out"], attrs={"value": Constant("values", values)})[ + 0 + ] @Graph.register() @@ -71,8 +73,14 @@ def gather(self, data, indices): @gs.Graph.register() -def slice(self, data, starts, ends, axes, steps): - return self.layer(op="Slice", inputs=[data, starts, ends, axes, steps], 
outputs=["slice_out"])[0] +def slice(self, data, starts=None, ends=None, axes=None, steps=None): + inputs = [] + for inp in [data, starts, ends, axes, steps]: + if inp is None: + break + inputs.append(inp) + + return self.layer(op="Slice", inputs=inputs, outputs=["slice_out"])[0] @gs.Graph.register() @@ -82,9 +90,9 @@ def nested(self, inp, graph): @gs.Graph.register() def if_op(self, cond, then_graph, else_graph): - return self.layer(op="If", inputs=[cond], outputs=["if_out"], - attrs={"then_branch": then_graph, "else_branch": else_graph})[0] - + return self.layer( + op="If", inputs=[cond], outputs=["if_out"], attrs={"then_branch": then_graph, "else_branch": else_graph} + )[0] # Generates a graph where an outer node has no outputs except @@ -105,8 +113,7 @@ def nested_graph(): subgraph_identity0 = Node(op="Identity", inputs=[id_out], outputs=[subgraph_id_out]) subgraph_identity1 = Node(op="Identity", inputs=[subgraph_id_out], outputs=subgraph_outputs) - subgraph = Graph(nodes=[subgraph_identity0, subgraph_identity1], - inputs=subgraph_inputs, outputs=subgraph_outputs) + subgraph = Graph(nodes=[subgraph_identity0, subgraph_identity1], inputs=subgraph_inputs, outputs=subgraph_outputs) nested_out = Variable("nested_out") nested_node = Node(op="Nested", attrs={"body": subgraph}, inputs=[inp], outputs=[nested_out]) @@ -137,7 +144,6 @@ def fake_add(self, a, b): assert len(graph.nodes) == 1 assert graph.nodes[-1].op == "Add" - def test_register_opset(self): @Graph.register(opsets=[11]) def fake_add(self, a, b): @@ -169,7 +175,6 @@ def test_layer_with_attrs(self): assert graph.nodes[-1].name == "node" assert graph.nodes[-1].attrs["fake_attr"] == 0 - def test_layer_with_tensors(self): x0 = Variable("x0") x1 = Variable("x1") @@ -183,7 +188,6 @@ def test_layer_with_tensors(self): assert graph.nodes[-1].inputs == [x0, x1] assert graph.nodes[-1].outputs == outputs - def test_layer_with_strings(self): x0 = "x0" x1 = "x1" @@ -197,7 +201,6 @@ def test_layer_with_strings(self): assert [prefix in tensor.name for prefix, tensor in zip([y0, y1], graph.nodes[-1].outputs)] assert graph.nodes[-1].outputs == outputs - def test_layer_with_arrays(self): x0 = np.array([1]) x1 = np.array([1]) @@ -212,10 +215,9 @@ def test_layer_with_arrays(self): assert graph.nodes[-1].inputs[1].values == x1 assert graph.nodes[-1].outputs == outputs - def test_layer_with_iterables(self): x0 = [1] - x1 = (1, ) + x1 = (1,) y0 = "y0" y1 = "y1" graph = Graph() @@ -228,7 +230,6 @@ def test_layer_with_iterables(self): assert graph.nodes[-1].outputs == outputs - def tensors_linear_graph(): inputs = [Variable(name="x")] intermediate0 = Variable(name="intermediate0") @@ -248,7 +249,6 @@ def tensors_linear_graph(): return Graph(nodes=nodes, inputs=inputs, outputs=outputs), nodes, tensors - class TestTensors(object): # Calling `tensors()` should not modify tensors in the graph. 
def test_tensors_does_not_modify_tensors(self): @@ -265,7 +265,6 @@ def test_tensors_does_not_modify_tensors(self): assert tensor.inputs == graph_tensor.inputs assert tensor.outputs == graph_tensor.outputs - # Check that tensors includes tensors not attached to nodes def test_tensors_includes_non_node_tensors(self): X = Constant("X", values=np.ones(shape=(64, 64), dtype=np.float32)) @@ -274,10 +273,9 @@ def test_tensors_includes_non_node_tensors(self): assert "X" in tensor_map assert tensor_map["X"] == X - def test_tensors_check_duplicates(self): inputs = [Variable(name="x")] - outputs = [Variable(name="x")] # Distinct tensors with the same name + outputs = [Variable(name="x")] # Distinct tensors with the same name nodes = [ Node(op="Add", name="Test", inputs=inputs, outputs=outputs), ] @@ -361,6 +359,7 @@ def toposort_multi_tier_input_graph(): toposort_multi_tier_input_graph, ] + class TestToposort(object): @pytest.mark.parametrize("toposort_test_case", TOPOSORT_TEST_CASES) def test_topologically_sort(self, toposort_test_case): @@ -369,7 +368,6 @@ def test_topologically_sort(self, toposort_test_case): graph.toposort() assert graph.nodes == expected_node_order - @pytest.mark.parametrize("toposort_test_case", TOPOSORT_TEST_CASES) def test_toposort_nested(self, toposort_test_case): subgraph, expected_node_order = toposort_test_case() @@ -449,12 +447,11 @@ def test_get_used_node_ids(self, graph): assert unused_tensor not in used_tensors assert all([used_tensor in used_tensors for used_tensor in graph_used_tensors]) - def test_multi_tier(self): graph, _ = toposort_multi_tier_output_graph() tensor = graph.outputs.pop() unused_node = tensor.inputs[0] - graph.cleanup() # Should remove just the Test2 node as out1 is still an output. + graph.cleanup() # Should remove just the Test2 node as out1 is still an output. 
        assert unused_node not in graph.nodes
        assert len(graph.nodes) == 2
        assert len(graph.outputs) == 2
@@ -462,9 +459,8 @@ def test_multi_tier(self):
        tensor_map = graph.tensors()
        assert tensor.name not in tensor_map
-
    def test_remove_unused_node_outputs(self):
-        graph, _ = toposort_linear_graph()
+        graph, _ = toposort_linear_graph()
        graph.toposort()
        graph_output = graph.outputs[0]
@@ -475,8 +471,7 @@ def test_remove_unused_node_outputs(self):
        graph.cleanup(remove_unused_node_outputs=True)
        assert dummy not in graph.nodes[1].outputs
-        assert graph.outputs[0] == graph_output # Graph outputs will never be removed
-
+        assert graph.outputs[0] == graph_output  # Graph outputs will never be removed
    def test_graph_input_producers(self):
        graph, _ = toposort_linear_graph()
@@ -489,7 +484,6 @@ def test_graph_input_producers(self):
        cleaned_tensor_map = graph.tensors()
        assert "x" not in cleaned_tensor_map
-
    @pytest.mark.parametrize("remove_unused_graph_inputs", [True, False])
    def test_independent_path(self, remove_unused_graph_inputs):
        graph, _ = toposort_linear_graph()
@@ -507,7 +501,6 @@ def test_independent_path(self, remove_unused_graph_inputs):
        assert indep0.name not in tensor_map or not remove_unused_graph_inputs
        assert indep1.name not in tensor_map or not remove_unused_graph_inputs
-
    def test_nested_graph(self, nested_graph):
        nested_node = nested_graph.nodes[1]
        nested_inp = nested_node.inputs[0]
@@ -529,15 +522,14 @@ def test_nested_graph(self, nested_graph):
        nested_graph.cleanup(recurse_subgraphs=True)
        assert not subgraph.nodes
-
    def test_node_used_only_in_nested_graph(self):
-        X = Variable("X", dtype=np.float32, shape=(1, ))
-        Y = Variable("Y", dtype=np.float32, shape=(1, ))
+        X = Variable("X", dtype=np.float32, shape=(1,))
+        Y = Variable("Y", dtype=np.float32, shape=(1,))
        graph = Graph(inputs=[X, Y])
-        X_p = graph.identity(X) # X_p is only used by the subgraph, not in the outer graph.
+        X_p = graph.identity(X)  # X_p is only used by the subgraph, not in the outer graph.
-        subgraph_inp = Variable("subgraph_input", dtype=np.float32, shape=(1, ))
+        subgraph_inp = Variable("subgraph_input", dtype=np.float32, shape=(1,))
        subgraph = Graph(inputs=[subgraph_inp])
        subgraph.outputs = [subgraph.add(subgraph_inp, X_p)]
@@ -548,7 +540,6 @@ def test_node_used_only_in_nested_graph(self):
        assert graph.nodes[0].op == "Identity"
        assert graph.nodes[0].inputs == [X]
-
    def test_input_is_output(self):
        graph = Graph()
@@ -558,7 +549,7 @@ def test_input_is_output(self):
        C = graph.add(A, B)
        graph.inputs = [A, B]
-        graph.outputs = [C, B, A] # Out of order w/ respect to Add node inputs
+        graph.outputs = [C, B, A]  # Out of order w/ respect to Add node inputs
        # Graph should remain unchanged after cleanup, including I/O tensors.
        graph.cleanup()
@@ -589,7 +580,6 @@ def make_graph():
        assert graph != new_graph
        assert new_graph == make_graph()
-
    def test_copy_with_subgraph(self, nested_graph):
        new_graph = nested_graph.copy()
        assert new_graph == nested_graph
@@ -616,7 +606,6 @@ def test_copy_with_subgraph(self, nested_graph):
        assert len(nested_graph.nodes) == 2
        assert len(subgraph.nodes) == 2
-
    # If the subgraph has a tensor with the same name as the outer graph,
    # the subgraph copy should include a copy of the subgraph tensor, not the outer
    # graph tensor.
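    # A short illustration of the case below (a sketch inferred from these tests, not part of
    # the patch): when the outer graph and the subgraph both contain a Variable named "x",
    # copying the outer graph must duplicate the subgraph's "x" rather than aliasing the
    # outer graph's tensor, so that, for example,
    #
    #     graph_copy = graph.copy()
    #     body_copy = graph_copy.nodes[0].attrs["body"]
    #     body_copy.inputs[0].shape = ("dynamic", 2)   # leaves the original subgraph input untouched
    #
    # only ever modifies the copied subgraph tensor.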
@@ -633,7 +622,6 @@ def test_copy_with_subgraph_dup_tensors(self): graph_copy = graph.copy() assert graph_copy.nodes[0].attrs["body"].inputs[0].shape == (1, 2) - def test_copy_with_subgraph_dup_const_tensors(self): inp = Constant("input", values=np.ones(dtype=np.float32, shape=(4, 5))) graph = Graph() @@ -728,7 +716,6 @@ def test_basic(self, simple_foldable, partitioning): # Value should be computed correctly assert np.all(simple_foldable.nodes[0].inputs[1].values == np.ones(shape=(1, 3), dtype=np.float32) * 2) - def test_one_hop(self, one_hop_foldable): inp = one_hop_foldable.inputs[0] @@ -742,7 +729,6 @@ def test_one_hop(self, one_hop_foldable): # Value should be computed correctly assert np.all(one_hop_foldable.nodes[0].inputs[1].values == np.ones(shape=(1, 3), dtype=np.float32) * 3) - def test_with_invalid_nodes(self, foldable_with_invalid_node): foldable_with_invalid_node.fold_constants(partitioning="recursive").cleanup() @@ -754,21 +740,17 @@ def test_with_invalid_nodes(self, foldable_with_invalid_node): assert foldable_with_invalid_node.nodes[2].op == "Add" assert np.all(tensor_map["c"].values == (np.ones(shape=(1, 3), dtype=np.float32) * 2)) - def test_with_invalid_nodes_no_recursive(self, foldable_with_invalid_node): # No folding should take place without recursive partitioning original = foldable_with_invalid_node.copy() assert foldable_with_invalid_node.fold_constants() == original - def test_no_foldable_constants(self): inp0 = Variable("input0", shape=(1, 3), dtype=np.float32) inp1 = Variable("input1", shape=(1, 3), dtype=np.float32) out = Variable("output", shape=(1, 3), dtype=np.float32) - nodes = [ - Node("Add", inputs=[inp0, inp1], outputs=[out]) - ] + nodes = [Node("Add", inputs=[inp0, inp1], outputs=[out])] graph = Graph(nodes=nodes, inputs=[inp0, inp1], outputs=[out]) @@ -777,7 +759,6 @@ def test_no_foldable_constants(self): assert len(graph.nodes) == 1 assert graph.nodes[0].inputs == [inp0, inp1] - def test_const_node(self): graph = Graph() values = np.ones((1, 3, 3), dtype=np.int64) @@ -791,7 +772,6 @@ def test_const_node(self): assert np.all(graph.outputs[0].values == values) assert not graph.nodes - def test_shape_of_constant_tensor(self): graph = Graph() values = np.ones((1, 3, 3), dtype=np.int64) @@ -804,7 +784,6 @@ def test_shape_of_constant_tensor(self): assert isinstance(graph.outputs[0], Constant) assert np.all(graph.outputs[0].values == (1, 3, 3)) - def test_shape_of_constant_node(self): graph = Graph() values = np.ones((1, 3, 3), dtype=np.int64) @@ -817,7 +796,6 @@ def test_shape_of_constant_node(self): assert isinstance(graph.outputs[0], Constant) assert np.all(graph.outputs[0].values == (1, 3, 3)) - # Cannot fold shape nodes if they have dynamically shaped inputs. 
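    # For instance (a minimal sketch using the `shape` helper registered at the top of this
    # file and the default fold_shapes behavior; not part of the patch): only a fully static
    # input shape can fold to a Constant.
    #
    #     static = Variable("static", dtype=np.float32, shape=(1, 3, 4))
    #     dynamic = Variable("dynamic", dtype=np.float32, shape=("batch", 3, 4))
    #     g = Graph(inputs=[static, dynamic])
    #     g.outputs = [g.shape(static), g.shape(dynamic)]
    #     g.fold_constants()
    #     # g.outputs[0] is now Constant([1, 3, 4]); g.outputs[1] remains a Variable.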
def test_shape_of_variable_tensor_dynamic_shape(self): var = Variable("var", dtype=np.float32, shape=("", -1, 0, 4)) @@ -830,7 +808,6 @@ def test_shape_of_variable_tensor_dynamic_shape(self): assert graph.nodes[0].op == "Shape" assert isinstance(graph.outputs[0], Variable) - def test_shape_of_variable_tensor_static_shape(self): var = Variable("var", dtype=np.float32, shape=(1, 3, 4)) graph = Graph(inputs=[var]) @@ -843,11 +820,10 @@ def test_shape_of_variable_tensor_static_shape(self): assert isinstance(graph.outputs[0], Constant) assert np.all(graph.outputs[0].values == (1, 3, 4)) - def test_shape_of_variable_tensor_multiple_shapes(self): graph = Graph() var = Variable("var", dtype=np.float32, shape=(1, 3, 4)) - var2 = Variable("var2", dtype=np.float32, shape=tuple()) # Scalar + var2 = Variable("var2", dtype=np.float32, shape=tuple()) # Scalar graph.inputs = [var, var2] graph.outputs = [graph.shape(var), graph.identity(var), graph.shape(var2)] @@ -860,7 +836,6 @@ def test_shape_of_variable_tensor_multiple_shapes(self): assert isinstance(graph.outputs[2], Constant) assert np.all(graph.outputs[2].values == tuple()) - def test_shape_of_variable_tensor_static_shape_no_fold(self): graph = Graph() var = Variable("var", dtype=np.float32, shape=(1, 3, 4)) @@ -873,7 +848,6 @@ def test_shape_of_variable_tensor_static_shape_no_fold(self): assert graph.nodes[0].op == "Shape" assert isinstance(graph.outputs[0], Variable) - # Constant folding should not cause constant tensors in the model to be loaded. def test_no_load_constants(self): graph = gs.import_onnx(const_foldable().load()) @@ -885,19 +859,21 @@ def check_no_const_loaded(graph): for tensor in graph.tensors().values(): if isinstance(tensor, Constant) and isinstance(tensor._values, LazyValues): num_lazy_constants += 1 - assert num_lazy_constants == 3 # Graph starts with 3 constants - none should be loaded. + assert num_lazy_constants == 3 # Graph starts with 3 constants - none should be loaded. 
        check_no_const_loaded(graph)
        check_no_const_loaded(new_graph)
-
-    @pytest.mark.parametrize("shape, indices", [
-        (("batch", 3, "height", "width"), 1), # Scalar indices case
-        (None, 1), # Shape not inferred case
-        (("batch", 3, "height", "width"), [1]),
-        (("batch", 3, "height", 224), [1, 3]),
-        (("batch", 3, 224, 224), [1, 2, 3]),
-    ])
+    @pytest.mark.parametrize(
+        "shape, indices",
+        [
+            (("batch", 3, "height", "width"), 1),  # Scalar indices case
+            (None, 1),  # Shape not inferred case
+            (("batch", 3, "height", "width"), [1]),
+            (("batch", 3, "height", 224), [1, 3]),
+            (("batch", 3, 224, 224), [1, 2, 3]),
+        ],
+    )
    def test_shape_gather(self, shape, indices):
        indices = np.array(indices)
@@ -924,24 +900,28 @@ def test_shape_gather(self, shape, indices):
        assert isinstance(graph.outputs[1], Variable)
        assert isinstance(graph.outputs[2], Variable)
-
-    @pytest.mark.parametrize("shape, starts, ends, axes, steps, expected", [
-        (("batch", 3, "height", "width"), 1, 2, 0, 1, [3]), # Scalar starts/ends case
-        (("batch", 3, "height", "width"), [1], [2], [0], [1], [3]),
-        (("batch", 3, 5, "width"), [1], [-1], [0], [1], [3, 5]), # Negative ends case
-        (("batch", 3, 5, 7), [1], [2000], [0], [1], [3, 5, 7]), # Past end, ends case
-        (("batch", 3, 5, 7), [-2], [4], [0], [1], [5, 7]), # Negative starts case
-        (("batch", 3, 5, 7), [-2], [4], [1], [1], None), # Non-zero axes case
-        (("batch", 3, 5, "width"), [-2], [4], [1], [1], None), # Dynamic case
-        (("batch", 3, 5, 7), [1], [4], [0], [2], [3, 7]), # Non-one steps case
-        (("batch", 3, 5, 7), [4], [0], [0], [-1], [7, 5, 3]), # Negative steps case
-    ])
+    @pytest.mark.parametrize(
+        "shape, starts, ends, axes, steps, expected",
+        [
+            (("batch", 3, "height", "width"), 1, 2, 0, 1, [3]),  # Scalar starts/ends case
+            (("batch", 3, "height", "width"), [1], [2], [0], [1], [3]),
+            (("batch", 3, 5, "width"), [1], [-1], [0], [1], [3, 5]),  # Negative ends case
+            (("batch", 3, 5, 7), [1], [2000], [0], [1], [3, 5, 7]),  # Past end, ends case
+            (("batch", 3, 5, 7), [-2], [4], [0], [1], [5, 7]),  # Negative starts case
+            (("batch", 3, 5, 7), [-2], [4], [1], [1], None),  # Non-zero axes case
+            (("batch", 3, 5, "width"), [-2], [4], [1], [1], None),  # Dynamic case
+            (("batch", 3, 5, 7), [1], [4], [0], [2], [3, 7]),  # Non-one steps case
+            (("batch", 3, 5, 7), [4], [0], [0], [-1], [7, 5, 3]),  # Negative steps case
+        ],
+    )
    def test_shape_slice(self, shape, starts, ends, axes, steps, expected):
        inp = Variable("input", dtype=np.float32, shape=shape)
        graph = Graph(inputs=[inp])
        inp_shape = graph.shape(inp)
-        graph.outputs = [graph.slice(inp_shape, np.array(starts), np.array(ends), axes=np.array(axes), steps=np.array(steps))]
+        graph.outputs = [
+            graph.slice(inp_shape, np.array(starts), np.array(ends), axes=np.array(axes), steps=np.array(steps))
+        ]
        graph.fold_constants()
@@ -951,12 +931,33 @@ def test_shape_slice(self, shape, starts, ends, axes, steps, expected):
        else:
            assert isinstance(graph.outputs[0], Variable)
+    # In the single input case, we should derive starts/ends/axes/steps from the attributes.
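    # A rough sketch of that case (it relies on the `slice` helper above now dropping trailing
    # `None` inputs; not part of the patch itself): with only a `data` input, the Slice
    # parameters come from the node's attributes rather than from its inputs.
    #
    #     inp = Variable("inp", dtype=np.int64, shape=(5, 6, 3, 2))
    #     g = Graph(inputs=[inp])
    #     g.outputs = [g.slice(g.shape(inp))]        # single-input Slice: only `data` is passed
    #     slice_node = g.outputs[0].inputs[0]
    #     slice_node.attrs = {"axes": [0], "starts": [1], "ends": [3], "steps": [2]}
    #     g.fold_constants()
    #     # g.outputs[0] folds to Constant([6]), i.e. (5, 6, 3, 2)[1:3:2]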
+ def test_shape_slice_single_input(self): + inp = Variable("input", dtype=np.int64, shape=(5, 6, 3, 2)) + graph = Graph(inputs=[inp]) + + inp_shape = graph.shape(inp) + graph.outputs = [graph.slice(inp_shape)] + + slice_node = graph.outputs[0].inputs[0] + + slice_node.attrs = { + "axes": [0], + "starts": [1], + "ends": [3], + "steps": [2], + } + + graph.fold_constants() + + assert isinstance(graph.outputs[0], Constant) + assert np.all(graph.outputs[0].values == inp.shape[1:3:2]) def test_with_nested_graph(self): - cond = gs.Variable("cond", dtype=np.bool, shape=(1, )) + cond = gs.Variable("cond", dtype=np.bool, shape=(1,)) - X = gs.Variable("X", dtype=np.float32, shape=(1, )) - Y = gs.Constant("Y", values=np.ones((1, ), dtype=np.float32)) + X = gs.Variable("X", dtype=np.float32, shape=(1,)) + Y = gs.Constant("Y", values=np.ones((1,), dtype=np.float32)) graph = Graph(inputs=[X, cond]) then_graph = Graph(name="Then") @@ -977,10 +978,9 @@ def test_with_nested_graph(self): assert isinstance(else_graph.nodes[0].inputs[1], Constant) assert np.all(else_graph.nodes[0].inputs[1].values == (Y.values * 2)) - def test_const_inp_but_non_foldable_nested_graph(self): cond = gs.Constant("cond", values=np.array(True)) - X = gs.Variable("X", dtype=np.float32, shape=(1, )) + X = gs.Variable("X", dtype=np.float32, shape=(1,)) graph = Graph(inputs=[X]) @@ -1016,7 +1016,6 @@ def test_io_cannot_be_sync_list_on_init(self): assert not isinstance(graph.inputs, SynchronizedList) assert not isinstance(graph.outputs, SynchronizedList) - def test_io_cannot_be_sync_list_on_assign(self): inp = Variable("input0", shape=(1, 3), dtype=np.float32) out = Variable("input1", shape=(1, 3), dtype=np.float32) diff --git a/tools/onnx-graphsurgeon/tests/onnx_models.py b/tools/onnx-graphsurgeon/tests/onnx_models.py index 9f289b61..f84f46bb 100644 --- a/tools/onnx-graphsurgeon/tests/onnx_models.py +++ b/tools/onnx-graphsurgeon/tests/onnx_models.py @@ -31,6 +31,7 @@ TEST_ROOT = os.path.realpath(os.path.dirname(__file__)) + class Model(object): def __init__(self, path: str, inputs: List[Tensor], outputs: List[Tensor], nodes: List[Node], opset: int): self.path = path @@ -48,6 +49,7 @@ def assert_equal(self, graph: Graph): # Break down fields to make debugging failures easier. for actual, expected in zip(graph.nodes, self.nodes): + def check_tensor_io(actensor, extensor): def check_list(aclist, exlist): G_LOGGER.debug("Actual node list: {:}\n\nExpected node list: {:}".format(aclist, exlist)) @@ -60,7 +62,6 @@ def check_list(aclist, exlist): G_LOGGER.debug("Checking tensor: {:} outputs".format(actensor.name)) check_list(actensor.outputs, extensor.outputs) - G_LOGGER.debug("Actual Node: {:}\n\nExpected Node: {:}".format(actual, expected)) assert actual.op == expected.op assert actual.inputs == expected.inputs @@ -84,7 +85,6 @@ def check_list(aclist, exlist): assert graph.outputs == self.outputs G_LOGGER.debug("Graph outputs matched") - def __str__(self): return os.path.basename(self.path) @@ -133,7 +133,12 @@ def load_initializer(index: int) -> np.ndarray: attrs = OrderedDict() attrs["direction"] = "forward" attrs["hidden_size"] = 5 - node = Node(op="LSTM", attrs=attrs, inputs=[X, W, R, B, Variable.empty(), Variable.empty(), initial_c], outputs=[Y, Y_h, Y_c]) + node = Node( + op="LSTM", + attrs=attrs, + inputs=[X, W, R, B, Variable.empty(), Variable.empty(), initial_c], + outputs=[Y, Y_h, Y_c], + ) # Initializers will not be included in the graph inputs. 
return Model(path, inputs=[X], outputs=[Y, Y_h, Y_c], nodes=[node], opset=OnnxImporter.get_opset(model)) @@ -144,10 +149,10 @@ def scan_model(): model = onnx.load(path) # Body graph - sum_in = Variable(name="sum_in", dtype=np.float32, shape=(2, )) - next = Variable(name="next", dtype=np.float32, shape=(2, )) - sum_out = Variable(name="sum_out", dtype=np.float32, shape=(2, )) - scan_out = Variable(name="scan_out", dtype=np.float32, shape=(2, )) + sum_in = Variable(name="sum_in", dtype=np.float32, shape=(2,)) + next = Variable(name="next", dtype=np.float32, shape=(2,)) + sum_out = Variable(name="sum_out", dtype=np.float32, shape=(2,)) + scan_out = Variable(name="scan_out", dtype=np.float32, shape=(2,)) body_nodes = [ Node(op="Add", inputs=[sum_in, next], outputs=[sum_out]), @@ -157,11 +162,11 @@ def scan_model(): # Outer graph inputs = [ - Variable(name="initial", dtype=np.float32, shape=(2, )), + Variable(name="initial", dtype=np.float32, shape=(2,)), Variable(name="x", dtype=np.float32, shape=(3, 2)), ] outputs = [ - Variable(name="y", dtype=np.float32, shape=(2, )), + Variable(name="y", dtype=np.float32, shape=(2,)), Variable(name="z", dtype=np.float32, shape=(3, 2)), ] @@ -226,4 +231,4 @@ def ext_weights(): def const_foldable(): path = os.path.join(TEST_ROOT, "models", "const_foldable.onnx") - return Model(path, inputs=None, outputs=None, nodes=None, opset=None) # Only used for path. + return Model(path, inputs=None, outputs=None, nodes=None, opset=None) # Only used for path. diff --git a/tools/onnx-graphsurgeon/tests/test_api.py b/tools/onnx-graphsurgeon/tests/test_api.py index d1cc830a..2fb67aef 100644 --- a/tools/onnx-graphsurgeon/tests/test_api.py +++ b/tools/onnx-graphsurgeon/tests/test_api.py @@ -22,16 +22,15 @@ import tempfile import onnx + class TestApi(object): def setup_method(self): self.imported_graph = OnnxImporter.import_graph(identity_model().load().graph) - def test_import(self): graph = gs.import_onnx(onnx.load(identity_model().path)) assert graph == self.imported_graph - def test_export(self): with tempfile.NamedTemporaryFile() as f: onnx_model = gs.export_onnx(self.imported_graph) diff --git a/tools/onnx-graphsurgeon/tests/test_examples.py b/tools/onnx-graphsurgeon/tests/test_examples.py index caecb1ab..ff5bc18f 100644 --- a/tools/onnx-graphsurgeon/tests/test_examples.py +++ b/tools/onnx-graphsurgeon/tests/test_examples.py @@ -30,6 +30,7 @@ ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir)) EXAMPLES_ROOT = os.path.join(ROOT_DIR, "examples") + class Artifact(object): def __init__(self, name, infer=True): self.name = name @@ -54,7 +55,7 @@ def ignore_command(cmd): return "pip" in cmd commands = [] - with open(readme, 'r') as f: + with open(readme, "r") as f: in_command_block = False for line in f.readlines(): if not in_command_block and "```bash" in line: diff --git a/tools/onnx-graphsurgeon/tests/test_exporters.py b/tools/onnx-graphsurgeon/tests/test_exporters.py index 7cdce7f0..09db8d21 100644 --- a/tools/onnx-graphsurgeon/tests/test_exporters.py +++ b/tools/onnx-graphsurgeon/tests/test_exporters.py @@ -25,9 +25,15 @@ from onnx_graphsurgeon.ir.node import Node from onnx_graphsurgeon.ir.tensor import Constant, LazyValues, Tensor, Variable -from onnx_models import (dim_param_model, ext_weights, identity_model, - initializer_is_output_model, lstm_model, - nested_dup_names, scan_model) +from onnx_models import ( + dim_param_model, + ext_weights, + identity_model, + initializer_is_output_model, + lstm_model, + nested_dup_names, + 
scan_model, +) class TestOnnxExporter(object): @@ -42,7 +48,6 @@ def test_export_constant_tensor_lazy_values_to_tensor_proto(self): onnx_tensor = OnnxExporter.export_tensor_proto(tensor) assert isinstance(tensor._values, LazyValues) - def test_export_constant_tensor_to_tensor_proto(self): name = "constant_tensor" shape = (3, 224, 224) @@ -55,7 +60,6 @@ def test_export_constant_tensor_to_tensor_proto(self): assert onnx_tensor.data_type == onnx.TensorProto.FLOAT assert tuple(onnx_tensor.dims) == shape - def test_export_constant_tensor_to_value_info_proto(self): name = "constant_tensor" shape = (3, 224, 224) @@ -71,7 +75,6 @@ def test_export_constant_tensor_to_value_info_proto(self): onnx_shape.append(dim.dim_value) assert tuple(onnx_shape) == shape - def test_export_variable_tensor(self): name = "variable_tensor" shape = (3, 224, 224) @@ -87,7 +90,6 @@ def test_export_variable_tensor(self): onnx_shape.append(dim.dim_value) assert tuple(onnx_shape) == shape - def test_export_variable_tensor_empty_dim_param(self): shape = ("", 224, 224) @@ -99,7 +101,6 @@ def test_export_variable_tensor_empty_dim_param(self): onnx_shape.append(dim.dim_value if dim.HasField("dim_value") else dim.dim_param) assert tuple(onnx_shape) == shape - # When a tensor shape is unknown, we should leave the shape field empty. def test_export_variable_tensor_empty_shape(self): shape = None @@ -108,7 +109,6 @@ def test_export_variable_tensor_empty_shape(self): onnx_tensor = OnnxExporter.export_value_info_proto(tensor, do_type_check=True) assert not onnx_tensor.type.tensor_type.HasField("shape") - # When a tensor shape is unknown, we should leave the shape field empty. def test_export_variable_tensor_scalar_shape(self): shape = [None] @@ -118,7 +118,6 @@ def test_export_variable_tensor_scalar_shape(self): assert not onnx_tensor.type.tensor_type.shape.dim[0].HasField("dim_param") assert not onnx_tensor.type.tensor_type.shape.dim[0].HasField("dim_value") - # TODO: Test subgraph export. def test_export_node(self): name = "TestNode" @@ -158,18 +157,28 @@ def test_export_node(self): elif isinstance(attr[0], str): assert [s.decode() for s in onnx_attr.strings] == attr else: - raise AssertionError("Unrecognized list attribute: ({:}: {:}) of type: {:}".format(name, attr, type(attr))) + raise AssertionError( + "Unrecognized list attribute: ({:}: {:}) of type: {:}".format(name, attr, type(attr)) + ) else: raise AssertionError("Unrecognized attribute: ({:}: {:}) of type: {:}".format(name, attr, type(attr))) - # See test_importers for import correctness checks # This function first imports an ONNX graph, and then re-exports it with no changes. # The exported ONNX graph should exactly match the original. 
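    # Roughly, the round trip below amounts to (a sketch; the exact assertions live in the test):
    #
    #     onnx_graph = identity_model().load().graph
    #     graph = OnnxImporter.import_graph(onnx_graph)
    #     exported = OnnxExporter.export_graph(graph)
    #     # `exported` should describe the same nodes, tensors, and I/O as `onnx_graph`.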
- @pytest.mark.parametrize("model", - [identity_model(), lstm_model(), scan_model(), dim_param_model(), - initializer_is_output_model(), nested_dup_names(), ext_weights()], - ids=lambda model: str(model)) + @pytest.mark.parametrize( + "model", + [ + identity_model(), + lstm_model(), + scan_model(), + dim_param_model(), + initializer_is_output_model(), + nested_dup_names(), + ext_weights(), + ], + ids=lambda model: str(model), + ) def test_export_graph(self, model): onnx_graph = model.load().graph graph = OnnxImporter.import_graph(onnx_graph) diff --git a/tools/onnx-graphsurgeon/tests/test_importers.py b/tools/onnx-graphsurgeon/tests/test_importers.py index 8c7fbb25..e2e50b1d 100644 --- a/tools/onnx-graphsurgeon/tests/test_importers.py +++ b/tools/onnx-graphsurgeon/tests/test_importers.py @@ -25,12 +25,19 @@ from onnx_graphsurgeon.ir.tensor import Constant, Variable from onnx_graphsurgeon.logger.logger import G_LOGGER -from onnx_models import (dim_param_model, ext_weights, identity_model, - initializer_is_output_model, lstm_model, - nested_dup_names, scan_model) +from onnx_models import ( + dim_param_model, + ext_weights, + identity_model, + initializer_is_output_model, + lstm_model, + nested_dup_names, + scan_model, +) G_LOGGER.severity = G_LOGGER.ULTRA_VERBOSE + class TestOnnxImporter(object): def test_import_variable_tensor(self): name = "test0" @@ -42,7 +49,6 @@ def test_import_variable_tensor(self): assert tensor.dtype == np.float32 assert tuple(tensor.shape) == shape - def test_import_constant_tensor(self): shape = (3, 3, 3) dtype = np.float32 @@ -52,7 +58,6 @@ def test_import_constant_tensor(self): assert tensor.dtype == dtype assert tuple(tensor.shape) == shape - def test_import_tensor_unknown_metadata(self): name = "test0" onnx_tensor = onnx.helper.make_empty_tensor_value_info(name) @@ -60,7 +65,6 @@ def test_import_tensor_unknown_metadata(self): assert type(tensor) == Variable assert tensor.name == name - # An empty string in `dim_param` should be treated like a dynamic dimension def test_import_empty_dim_param_tensor(self): shape = (1, 2, "non-empty", "") @@ -69,7 +73,6 @@ def test_import_empty_dim_param_tensor(self): assert type(tensor) == Variable assert tuple(tensor.shape) == shape - # Sometimes, tensor shape is not known, in which case we shouldn't import it def test_import_unknown_shape_tensor(self): shape = None @@ -78,10 +81,9 @@ def test_import_unknown_shape_tensor(self): assert type(tensor) == Variable assert tensor.shape is None - # Scalars can be represented in ONNX with a dim that includes neither a dim_param nor dim_value def test_import_empty_dim_tensor(self): - shape = (None, ) + shape = (None,) onnx_tensor = onnx.helper.make_tensor_value_info("test0", onnx.TensorProto.FLOAT, shape) onnx_tensor.type.tensor_type.shape.dim[0].ClearField("dim_value") onnx_tensor.type.tensor_type.shape.dim[0].ClearField("dim_param") @@ -90,7 +92,6 @@ def test_import_empty_dim_tensor(self): assert type(tensor) == Variable assert tuple(tensor.shape) == shape - # TODO: Test all attribute types - missing graph def test_import_node(self): op = "Test" @@ -105,7 +106,18 @@ def test_import_node(self): ints_attr = [4, 3, 2, 1] strings_attr = ["constant", "and", "variable"] - onnx_node = onnx.helper.make_node(op, inputs, outputs, float_attr=float_attr, int_attr=int_attr, str_attr=str_attr, tensor_attr=tensor_attr, floats_attr=floats_attr, ints_attr=ints_attr, strings_attr=strings_attr) + onnx_node = onnx.helper.make_node( + op, + inputs, + outputs, + float_attr=float_attr, + int_attr=int_attr, + 
str_attr=str_attr, + tensor_attr=tensor_attr, + floats_attr=floats_attr, + ints_attr=ints_attr, + strings_attr=strings_attr, + ) node = OnnxImporter.import_node(onnx_node, OrderedDict(), OrderedDict()) assert node.op == op assert node.attrs["float_attr"] == float_attr @@ -117,22 +129,30 @@ def test_import_node(self): assert node.attrs["ints_attr"] == ints_attr assert node.attrs["strings_attr"] == strings_attr - - @pytest.mark.parametrize("model", - [identity_model(), lstm_model(), scan_model(), dim_param_model(), - initializer_is_output_model(), nested_dup_names(), ext_weights()], - ids=lambda model: str(model)) + @pytest.mark.parametrize( + "model", + [ + identity_model(), + lstm_model(), + scan_model(), + dim_param_model(), + initializer_is_output_model(), + nested_dup_names(), + ext_weights(), + ], + ids=lambda model: str(model), + ) def test_import_graph(self, model): graph = OnnxImporter.import_graph(model.load().graph) model.assert_equal(graph) - def test_import_graph_value_info(self): model = onnx.shape_inference.infer_shapes(identity_model().load()) graph = OnnxImporter.import_graph(model.graph) tensors = graph.tensors() - assert all([type(tensor) == Variable and tensor.dtype is not None and tensor.shape for tensor in tensors.values()]) - + assert all( + [type(tensor) == Variable and tensor.dtype is not None and tensor.shape for tensor in tensors.values()] + ) def test_import_graph_tensor_map_preserved(self): model = identity_model() @@ -141,13 +161,11 @@ def test_import_graph_tensor_map_preserved(self): assert len(tensor_map) == 0 model.assert_equal(graph) - def test_import_graph_with_initializer(self): model = lstm_model() graph = OnnxImporter.import_graph(model.load().graph) model.assert_equal(graph) - def test_import_graph_with_dim_param(self): model = dim_param_model() graph = OnnxImporter.import_graph(model.load().graph) diff --git a/tools/onnx-graphsurgeon/tests/test_ir.py b/tools/onnx-graphsurgeon/tests/test_ir.py index 903006c3..f0f24197 100644 --- a/tools/onnx-graphsurgeon/tests/test_ir.py +++ b/tools/onnx-graphsurgeon/tests/test_ir.py @@ -23,6 +23,7 @@ G_LOGGER.severity = G_LOGGER.ULTRA_VERBOSE + class TensorBaseTests(object): def test_can_convert_in_place_to_constant(self): tensor = self.tensor.to_constant(values=np.ones((1, 3, 5, 5), dtype=np.float64)) @@ -117,7 +118,9 @@ def test_equals_name_mismatch(self): class TestConstant(TensorBaseTests): def setup_method(self): self.tensor = Constant(name="test_tensor", values=np.ones((1, 3, 5, 5), dtype=np.float64)) - self.input_node = Node(op="Add", outputs=[self.tensor]) # Doesn't make sense for Constants, but needed to make base tests happy. + self.input_node = Node( + op="Add", outputs=[self.tensor] + ) # Doesn't make sense for Constants, but needed to make base tests happy. 
self.output_node = Node(op="Add", inputs=[self.tensor]) def test_can_get_shape(self): @@ -194,7 +197,9 @@ def test_i_multiple_inputs(self): intermediate_tensor2 = Variable(name="intermediate2") input_node = Node(op="Add", name="Input", inputs=[self.input_tensor], outputs=[intermediate_tensor]) input_node2 = Node(op="Add", name="Input2", inputs=[self.input_tensor], outputs=[intermediate_tensor2]) - output_node = Node(op="Add", name="Out", inputs=[intermediate_tensor, intermediate_tensor2], outputs=[self.output_tensor]) + output_node = Node( + op="Add", name="Out", inputs=[intermediate_tensor, intermediate_tensor2], outputs=[self.output_tensor] + ) assert output_node.i() == input_node assert output_node.i(1) == input_node2 @@ -216,7 +221,9 @@ def test_o_multiple_outputs(self): class TestNodeIO(object): def setup_method(self, field_names): - self.tensors = [Variable(name="test_tensor_{:}".format(i), dtype=np.float32, shape=(1, 3, 224, 224)) for i in range(10)] + self.tensors = [ + Variable(name="test_tensor_{:}".format(i), dtype=np.float32, shape=(1, 3, 224, 224)) for i in range(10) + ] self.node = Node(op="Dummy") def get_lists(self, field_names):