Skip to content

Commit

Permalink
update for failing tests; add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipika Sikka committed Feb 8, 2024
1 parent c27bd28 commit d26b23b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/sparsezoo/analyze_v2/model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,13 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
else:
raise ValueError(f"{path} is not a valid argument")

# just need graph to get shape information; dont load external data
node_shapes, _ = extract_node_shapes_and_dtypes(onnx_model, onnx_model_path)

summary_analysis = SummaryAnalysis()
node_analyses = {}

# load external data for node analysis
onnx_model = onnx.load(onnx_model_path)
model_graph = ONNXGraph(onnx_model)

Expand Down
30 changes: 19 additions & 11 deletions src/sparsezoo/utils/node_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -61,7 +61,7 @@


def extract_nodes_shapes_and_dtypes_ort(
model: ModelProto, path: str
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
Expand All @@ -81,16 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
)
model_copy.graph.output.append(intermediate_layer_value_info)

parent_dir = Path(path).parent.absolute()
new_path = parent_dir / "model_new.onnx"
onnx.save(model_copy, new_path, save_as_external_data=True)

# using the ModelProto does not work for large models when running the session
# have to save again and pass the new path to the inference session
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 3

sess = onnxruntime.InferenceSession(
new_path, sess_options, providers=onnxruntime.get_available_providers()
)
if path:
parent_dir = Path(path).parent.absolute()
new_path = parent_dir / "model_new.onnx"
onnx.save(model_copy, new_path, save_as_external_data=True)
sess = onnxruntime.InferenceSession(
new_path, sess_options, providers=onnxruntime.get_available_providers()
)
else:
sess = onnxruntime.InferenceSession(
model_copy.SerializeToString(),
sess_options,
providers=onnxruntime.get_available_providers(),
)

input_value_dict = {}
for input in model_copy.graph.input:
Expand Down Expand Up @@ -173,7 +181,7 @@ def extract_nodes_shapes_and_dtypes_shape_inference(


def extract_nodes_shapes_and_dtypes(
model: ModelProto, path: str
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
"""
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model
Expand Down Expand Up @@ -314,7 +322,7 @@ def collate_output_dtypes(


def extract_node_shapes_and_dtypes(
model: ModelProto, path: str
model: ModelProto, path: Optional[str] = None
) -> Tuple[Dict[str, NodeShape], Dict[str, NodeDataType]]:
"""
Extracts the shape and dtype information for each node as NodeShape objects
Expand Down

0 comments on commit d26b23b

Please sign in to comment.