diff --git a/src/sparsezoo/analyze_v2/model_analysis.py b/src/sparsezoo/analyze_v2/model_analysis.py
index 328a1a01..39420b12 100644
--- a/src/sparsezoo/analyze_v2/model_analysis.py
+++ b/src/sparsezoo/analyze_v2/model_analysis.py
@@ -146,20 +146,25 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
     :param path: .onnx path or stub
     """
     if path.endswith(".onnx"):
-        onnx_model = load_model(path)
+        onnx_model = load_model(path, load_external_data=False)
+        onnx_model_path = path
     elif is_stub(path):
         model = Model(path, download_path)
         onnx_model_path = model.onnx_model.path
-        onnx_model = onnx.load(onnx_model_path)
+        onnx_model = onnx.load(onnx_model_path, load_external_data=False)
     else:
         raise ValueError(f"{path} is not a valid argument")
 
-    model_graph = ONNXGraph(onnx_model)
-    node_shapes, _ = extract_node_shapes_and_dtypes(model_graph.model)
+    # just need graph to get shape information; don't load external data
+    node_shapes, _ = extract_node_shapes_and_dtypes(onnx_model, onnx_model_path)
 
     summary_analysis = SummaryAnalysis()
     node_analyses = {}
 
+    # load external data for node analysis
+    onnx_model = onnx.load(onnx_model_path)
+    model_graph = ONNXGraph(onnx_model)
+
     for graph_order, node in enumerate(model_graph.nodes):
         node_id = extract_node_id(node)
         node_shape = node_shapes.get(node_id)
diff --git a/src/sparsezoo/utils/node_inference.py b/src/sparsezoo/utils/node_inference.py
index 570a06d2..30a2675a 100644
--- a/src/sparsezoo/utils/node_inference.py
+++ b/src/sparsezoo/utils/node_inference.py
@@ -18,7 +18,8 @@
 
 import logging
 from copy import deepcopy
-from typing import Any, Dict, List, NamedTuple, Tuple, Union
+from pathlib import Path
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import numpy
 import onnx
@@ -60,13 +61,14 @@
 
 def extract_nodes_shapes_and_dtypes_ort(
-    model: ModelProto,
+    model: ModelProto, path: Optional[str] = None
 ) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
     """
     Creates a modified model to expose intermediate outputs and runs an ONNX
     Runtime InferenceSession to obtain the output shape of each node.
 
     :param model: an ONNX model
+    :param path: absolute path to the original onnx model
     :return: a list of NodeArg with their shape exposed
     """
 
     import onnxruntime
@@ -79,11 +81,24 @@
         )
         model_copy.graph.output.append(intermediate_layer_value_info)
 
+    # using the ModelProto does not work for large models when running the session
+    # have to save again and pass the new path to the inference session
     sess_options = onnxruntime.SessionOptions()
     sess_options.log_severity_level = 3
-    sess = onnxruntime.InferenceSession(
-        model_copy.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
-    )
+
+    if path:
+        parent_dir = Path(path).parent.absolute()
+        new_path = parent_dir / "model_new.onnx"
+        onnx.save(model_copy, new_path, save_as_external_data=True)
+        sess = onnxruntime.InferenceSession(
+            str(new_path), sess_options, providers=onnxruntime.get_available_providers()
+        )
+    else:
+        sess = onnxruntime.InferenceSession(
+            model_copy.SerializeToString(),
+            sess_options,
+            providers=onnxruntime.get_available_providers(),
+        )
 
     input_value_dict = {}
     for input in model_copy.graph.input:
@@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(
 
 
 def extract_nodes_shapes_and_dtypes(
-    model: ModelProto,
+    model: ModelProto, path: Optional[str] = None
 ) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
     """
     Uses ONNX Runtime or shape inference to infer output shapes and dtypes
     from model
 
     :param model: model to extract output values from
+    :param path: absolute path to the original onnx model
     :return: output shapes and output data types
     """
     output_shapes = None
     output_dtypes = None
 
     try:
-        output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model)
+        output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model, path)
     except Exception as err:
         _LOGGER.warning(f"Extracting shapes using ONNX Runtime session failed: {err}")
@@ -306,18 +322,19 @@ def collate_output_dtypes(
 
 
 def extract_node_shapes_and_dtypes(
-    model: ModelProto,
+    model: ModelProto, path: Optional[str] = None
 ) -> Tuple[Dict[str, NodeShape], Dict[str, NodeDataType]]:
     """
     Extracts the shape and dtype information for each node as NodeShape objects
     and numpy dtypes.
 
     :param model: the loaded onnx.ModelProto to extract node shape information from
+    :param path: absolute path to the original onnx model
     :return: a mapping of node id to a NodeShape object
     """
 
     # Obtains output shapes for each model's node
-    output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model)
+    output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model, path)
 
     # Package output shapes into each node's inputs and outputs
     node_shapes = collate_output_shapes(model, output_shapes)
diff --git a/src/sparsezoo/utils/onnx/external_data.py b/src/sparsezoo/utils/onnx/external_data.py
index 4f213bb4..936554a5 100644
--- a/src/sparsezoo/utils/onnx/external_data.py
+++ b/src/sparsezoo/utils/onnx/external_data.py
@@ -174,19 +174,23 @@ def validate_onnx(model: Union[str, ModelProto]):
         raise ValueError(f"Invalid onnx model: {err}")
 
 
-def load_model(model: Union[str, ModelProto, Path]) -> ModelProto:
+def load_model(
+    model: Union[str, ModelProto, Path], load_external_data: bool = True
+) -> ModelProto:
     """
     Load an ONNX model from an onnx model file path. If a ModelProto is given,
     then it is returned.
 
     :param model: the model proto or path to the model ONNX file to check for loading
+    :param load_external_data: if a path is given, whether or not to also load the
+        external model data
     :return: the loaded ONNX ModelProto
     """
     if isinstance(model, ModelProto):
         return model
 
     if isinstance(model, (Path, str)):
-        return onnx.load(clean_path(model))
+        return onnx.load(clean_path(model), load_external_data=load_external_data)
 
     raise TypeError(f"unknown type given for model: {type(model)}")
 