Skip to content

Commit

Permalink
[sparsezoo.analyze] Fix pathway such that it works for larger models (#437)
Browse files Browse the repository at this point in the history

* fix analyze to work with larger models

* update for failing tests; add comments

* Update src/sparsezoo/utils/onnx/external_data.py

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

---------

Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 9, 2024
1 parent e6b12f6 commit c1a096f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 15 deletions.
13 changes: 9 additions & 4 deletions src/sparsezoo/analyze_v2/model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,25 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
:param path: .onnx path or stub
"""
if path.endswith(".onnx"):
onnx_model = load_model(path)
onnx_model = load_model(path, load_external_data=False)
onnx_model_path = path
elif is_stub(path):
model = Model(path, download_path)
onnx_model_path = model.onnx_model.path
onnx_model = onnx.load(onnx_model_path)
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
else:
raise ValueError(f"{path} is not a valid argument")

model_graph = ONNXGraph(onnx_model)
node_shapes, _ = extract_node_shapes_and_dtypes(model_graph.model)
# just need graph to get shape information; dont load external data
node_shapes, _ = extract_node_shapes_and_dtypes(onnx_model, onnx_model_path)

summary_analysis = SummaryAnalysis()
node_analyses = {}

# load external data for node analysis
onnx_model = onnx.load(onnx_model_path)
model_graph = ONNXGraph(onnx_model)

for graph_order, node in enumerate(model_graph.nodes):
node_id = extract_node_id(node)
node_shape = node_shapes.get(node_id)
Expand Down
35 changes: 26 additions & 9 deletions src/sparsezoo/utils/node_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import logging
from copy import deepcopy
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -60,13 +61,14 @@


def extract_nodes_shapes_and_dtypes_ort(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
InferenceSession to obtain the output shape of each node.
:param model: an ONNX model
:param path: absolute path to the original onnx model
:return: a list of NodeArg with their shape exposed
"""
import onnxruntime
Expand All @@ -79,11 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
)
model_copy.graph.output.append(intermediate_layer_value_info)

# using the ModelProto does not work for large models when running the session
# have to save again and pass the new path to the inference session
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 3
sess = onnxruntime.InferenceSession(
model_copy.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)

if path:
parent_dir = Path(path).parent.absolute()
new_path = parent_dir / "model_new.onnx"
onnx.save(model_copy, new_path, save_as_external_data=True)
sess = onnxruntime.InferenceSession(
new_path, sess_options, providers=onnxruntime.get_available_providers()
)
else:
sess = onnxruntime.InferenceSession(
model_copy.SerializeToString(),
sess_options,
providers=onnxruntime.get_available_providers(),
)

input_value_dict = {}
for input in model_copy.graph.input:
Expand Down Expand Up @@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(


def extract_nodes_shapes_and_dtypes(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model
:param model: model to extract output values from
:param path: absolute path to the original onnx model
:return: output shapes and output data types
"""
output_shapes = None
output_dtypes = None

try:
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model)
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model, path)
except Exception as err:
_LOGGER.warning(f"Extracting shapes using ONNX Runtime session failed: {err}")

Expand Down Expand Up @@ -306,18 +322,19 @@ def collate_output_dtypes(


def extract_node_shapes_and_dtypes(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, NodeShape], Dict[str, NodeDataType]]:
"""
Extracts the shape and dtype information for each node as NodeShape objects
and numpy dtypes.
:param model: the loaded onnx.ModelProto to extract node shape information from
:param path: absolute path to the original onnx model
:return: a mapping of node id to a NodeShape object
"""

# Obtains output shapes for each model's node
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model)
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model, path)

# Package output shapes into each node's inputs and outputs
node_shapes = collate_output_shapes(model, output_shapes)
Expand Down
8 changes: 6 additions & 2 deletions src/sparsezoo/utils/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,23 @@ def validate_onnx(model: Union[str, ModelProto]):
raise ValueError(f"Invalid onnx model: {err}")


def load_model(model: Union[str, ModelProto, Path]) -> ModelProto:
def load_model(
model: Union[str, ModelProto, Path], load_external_data: bool = True
) -> ModelProto:
"""
Load an ONNX model from an onnx model file path. If a ModelProto
is given, then it is returned.
:param model: the model proto or path to the model ONNX file to check for loading
:param load_external_data: if a path is given, whether or not to also load the
external model data
:return: the loaded ONNX ModelProto
"""
if isinstance(model, ModelProto):
return model

if isinstance(model, (Path, str)):
return onnx.load(clean_path(model))
return onnx.load(clean_path(model), load_external_data=load_external_data)

raise TypeError(f"unknown type given for model: {type(model)}")

Expand Down

0 comments on commit c1a096f

Please sign in to comment.