Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sparsezoo.analyze] Fix pathway such that it works for larger models #437

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/sparsezoo/analyze_v2/model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,25 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
:param path: .onnx path or stub
"""
if path.endswith(".onnx"):
onnx_model = load_model(path)
onnx_model = load_model(path, load_external_data=False)
onnx_model_path = path
elif is_stub(path):
model = Model(path, download_path)
onnx_model_path = model.onnx_model.path
onnx_model = onnx.load(onnx_model_path)
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
else:
raise ValueError(f"{path} is not a valid argument")

model_graph = ONNXGraph(onnx_model)
node_shapes, _ = extract_node_shapes_and_dtypes(model_graph.model)
    # just need graph to get shape information; don't load external data
node_shapes, _ = extract_node_shapes_and_dtypes(onnx_model, onnx_model_path)
dsikka marked this conversation as resolved.
Show resolved Hide resolved

summary_analysis = SummaryAnalysis()
node_analyses = {}

# load external data for node analysis
onnx_model = onnx.load(onnx_model_path)
model_graph = ONNXGraph(onnx_model)

for graph_order, node in enumerate(model_graph.nodes):
node_id = extract_node_id(node)
node_shape = node_shapes.get(node_id)
Expand Down
35 changes: 26 additions & 9 deletions src/sparsezoo/utils/node_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import logging
from copy import deepcopy
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -60,13 +61,14 @@


def extract_nodes_shapes_and_dtypes_ort(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
InferenceSession to obtain the output shape of each node.

:param model: an ONNX model
:param path: absolute path to the original onnx model
:return: a list of NodeArg with their shape exposed
"""
import onnxruntime
Expand All @@ -79,11 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
)
model_copy.graph.output.append(intermediate_layer_value_info)

# using the ModelProto does not work for large models when running the session
# have to save again and pass the new path to the inference session
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 3
sess = onnxruntime.InferenceSession(
model_copy.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)

if path:
parent_dir = Path(path).parent.absolute()
new_path = parent_dir / "model_new.onnx"
onnx.save(model_copy, new_path, save_as_external_data=True)
sess = onnxruntime.InferenceSession(
new_path, sess_options, providers=onnxruntime.get_available_providers()
)
else:
sess = onnxruntime.InferenceSession(
model_copy.SerializeToString(),
sess_options,
providers=onnxruntime.get_available_providers(),
)

input_value_dict = {}
for input in model_copy.graph.input:
Expand Down Expand Up @@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(


def extract_nodes_shapes_and_dtypes(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model

:param model: model to extract output values from
:param path: absolute path to the original onnx model
:return: output shapes and output data types
"""
output_shapes = None
output_dtypes = None

try:
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model)
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model, path)
except Exception as err:
_LOGGER.warning(f"Extracting shapes using ONNX Runtime session failed: {err}")

Expand Down Expand Up @@ -306,18 +322,19 @@ def collate_output_dtypes(


def extract_node_shapes_and_dtypes(
model: ModelProto,
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, NodeShape], Dict[str, NodeDataType]]:
"""
Extracts the shape and dtype information for each node as NodeShape objects
and numpy dtypes.

:param model: the loaded onnx.ModelProto to extract node shape information from
:param path: absolute path to the original onnx model
:return: a mapping of node id to a NodeShape object
"""

# Obtains output shapes for each model's node
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model)
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model, path)

# Package output shapes into each node's inputs and outputs
node_shapes = collate_output_shapes(model, output_shapes)
Expand Down
8 changes: 6 additions & 2 deletions src/sparsezoo/utils/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,23 @@ def validate_onnx(model: Union[str, ModelProto]):
raise ValueError(f"Invalid onnx model: {err}")


def load_model(
    model: Union[str, ModelProto, Path], load_external_data: bool = True
) -> ModelProto:
    """
    Load an ONNX model from an onnx model file path. If a ModelProto
    is given, then it is returned as-is.

    :param model: the model proto or path to the model ONNX file to check for loading
    :param load_external_data: if a path is given, whether or not to also load the
        external model data
    :return: the loaded ONNX ModelProto
    :raises TypeError: if ``model`` is neither a ModelProto nor a str/Path
    """
    # already a loaded proto -- load_external_data does not apply here
    if isinstance(model, ModelProto):
        return model

    if isinstance(model, (Path, str)):
        # skipping external data keeps loading cheap for very large models;
        # callers that need the weights pass load_external_data=True (default)
        return onnx.load(clean_path(model), load_external_data=load_external_data)

    raise TypeError(f"unknown type given for model: {type(model)}")

Expand Down
Loading