Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

ModelAnalysis.create(...) #281

Merged
merged 14 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/sparsezoo/analyze/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""

import copy
import logging
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy
Expand All @@ -27,6 +29,7 @@
from onnx import ModelProto, NodeProto
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt

from sparsezoo import Model
from sparsezoo.analyze.utils.models import (
DenseSparseOps,
NodeCounts,
Expand Down Expand Up @@ -68,6 +71,8 @@
"ModelAnalysis",
]

_LOGGER = logging.getLogger()


class YAMLSerializableBaseModel(BaseModel):
"""
Expand Down Expand Up @@ -964,6 +969,50 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
nodes=node_analyses,
)

@classmethod
def create(cls, file_path: Union[str, ModelProto]) -> "ModelAnalysis":
"""
Factory method to create a model analysis object from an onnx filepath,
sparsezoo stub, deployment directory, or a yaml file/raw string representing
a `ModelAnalysis` object

:param file_path: An instantiated ModelProto object, or path to an onnx
model, or SparseZoo stub, or path to a deployment directory, or path
to a yaml file or a raw yaml string representing a `ModelAnalysis`
object. This is used to create a new ModelAnalysis object
:returns: The created ModelAnalysis object
"""
if not isinstance(file_path, (str, ModelProto, Path)):
raise ValueError(
f"Invalid file_path type {type(file_path)} passed to "
f"ModelAnalysis.create(...)"
)

if isinstance(file_path, ModelProto):
return ModelAnalysis.from_onnx(onnx_file_path=file_path)

if Path(file_path).is_file():
return (
ModelAnalysis.parse_yaml_file(file_path=file_path)
if Path(file_path).suffix == ".yaml"
else ModelAnalysis.from_onnx(onnx_file_path=file_path)
)
if Path(file_path).is_dir():
_LOGGER.info(f"Loading `model.onnx` from deployment directory {file_path}")
return ModelAnalysis.from_onnx(Path(file_path) / "model.onnx")

if file_path.startswith("zoo:"):
return ModelAnalysis.from_onnx(
Model(file_path).deployment.get_file("model.onnx").path
)

if isinstance(file_path, str):
return ModelAnalysis.parse_yaml_raw(yaml_raw=file_path)

raise ValueError(
f"Invalid argument file_path {file_path} to create ModelAnalysis"
)

def summary(self) -> Dict[str, Any]:
"""
:return: A dict like object with summary of current analysis
Expand Down
18 changes: 1 addition & 17 deletions src/sparsezoo/analyze_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,10 @@
"""
import copy
import logging
from pathlib import Path
from typing import Optional

import click
import pandas as pd
from sparsezoo import Model
from sparsezoo.analyze import ModelAnalysis


Expand Down Expand Up @@ -101,10 +99,8 @@ def main(model_path: str, save: Optional[str], **kwargs):
f"--{unimplemented_feat} has not been implemented yet"
)

model_file_path = _get_model_file_path(model_path=model_path)

LOGGER.info("Starting Analysis ...")
analysis = ModelAnalysis.from_onnx(model_file_path)
analysis = ModelAnalysis.create(model_path)
LOGGER.info("Analysis complete, collating results...")

summary = analysis.summary()
Expand All @@ -117,18 +113,6 @@ def main(model_path: str, save: Optional[str], **kwargs):
analysis.yaml(file_path=save)


def _get_model_file_path(model_path: str):
if model_path.startswith("zoo:"):
LOGGER.info(f"Downloading files from SparseZoo: '{model_path}'")
model = Model(model_path)
model_path = Path(model.deployment.get_file("model.onnx").path)
elif Path(model_path).is_file():
model_path = model_path
else:
model_path = Path(model_path) / "model.onnx"
return model_path


def _display_summary_as_table(summary):
summary_copy = copy.copy(summary)
print(f"MODEL: {summary_copy.pop('MODEL')}", end="\n\n")
Expand Down
76 changes: 76 additions & 0 deletions tests/sparsezoo/analyze/test_model_analysis_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import onnx
import pytest

from sparsezoo import Model
from sparsezoo.analyze import ModelAnalysis


def onnx_stub():
return (
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/"
"imagenet/pruned95_quant-none"
)


def onnx_deployment_dir():
return Model(onnx_stub()).deployment.path


def onnx_local_path():
return str(Path(onnx_deployment_dir()) / "model.onnx")


def onnx_model():
return onnx.load(onnx_local_path())


@pytest.mark.parametrize(
"file_path, should_error",
[
(onnx_stub(), False),
(onnx_deployment_dir(), False),
(onnx_local_path(), False),
(onnx_model(), False),
(1, True),
],
)
def test_create(file_path, should_error):
if should_error:
with pytest.raises(ValueError, match="Invalid"):
ModelAnalysis.create(file_path)
else:
analysis = ModelAnalysis.create(file_path)
assert isinstance(analysis, ModelAnalysis)


@pytest.mark.parametrize(
"model_path",
[
onnx_local_path(),
],
)
def test_yaml_serialization(model_path, tmp_path):
analysis = ModelAnalysis.create(file_path=model_path)

yaml_file = str(tmp_path / "quantized-resnet.yaml")
analysis.yaml(file_path=yaml_file)

analysis_from_yaml = ModelAnalysis.create(file_path=yaml_file)

assert analysis.yaml() == analysis_from_yaml.yaml()