diff --git a/src/sparsezoo/analyze/analysis.py b/src/sparsezoo/analyze/analysis.py index 698245dd..319088a9 100644 --- a/src/sparsezoo/analyze/analysis.py +++ b/src/sparsezoo/analyze/analysis.py @@ -17,8 +17,10 @@ """ import copy +import logging from collections import defaultdict from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, List, Optional, Union import numpy @@ -27,6 +29,7 @@ from onnx import ModelProto, NodeProto from pydantic import BaseModel, Field, PositiveFloat, PositiveInt +from sparsezoo import Model from sparsezoo.analyze.utils.models import ( DenseSparseOps, NodeCounts, @@ -68,6 +71,8 @@ "ModelAnalysis", ] +_LOGGER = logging.getLogger() + class YAMLSerializableBaseModel(BaseModel): """ @@ -964,6 +969,50 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]): nodes=node_analyses, ) + @classmethod + def create(cls, file_path: Union[str, ModelProto]) -> "ModelAnalysis": + """ + Factory method to create a model analysis object from an onnx filepath, + sparsezoo stub, deployment directory, or a yaml file/raw string representing + a `ModelAnalysis` object + + :param file_path: An instantiated ModelProto object, or path to an onnx + model, or SparseZoo stub, or path to a deployment directory, or path + to a yaml file or a raw yaml string representing a `ModelAnalysis` + object. This is used to create a new ModelAnalysis object + :returns: The created ModelAnalysis object + """ + if not isinstance(file_path, (str, ModelProto, Path)): + raise ValueError( + f"Invalid file_path type {type(file_path)} passed to " + f"ModelAnalysis.create(...)" + ) + + if isinstance(file_path, ModelProto): + return ModelAnalysis.from_onnx(onnx_file_path=file_path) + + if Path(file_path).is_file(): + return ( + ModelAnalysis.parse_yaml_file(file_path=file_path) + if Path(file_path).suffix == ".yaml" + else ModelAnalysis.from_onnx(onnx_file_path=file_path) + ) + if Path(file_path).is_dir(): + _LOGGER.info(f"Loading `model.onnx` from deployment directory {file_path}") + return ModelAnalysis.from_onnx(Path(file_path) / "model.onnx") + + if file_path.startswith("zoo:"): + return ModelAnalysis.from_onnx( + Model(file_path).deployment.get_file("model.onnx").path + ) + + if isinstance(file_path, str): + return ModelAnalysis.parse_yaml_raw(yaml_raw=file_path) + + raise ValueError( + f"Invalid argument file_path {file_path} to create ModelAnalysis" + ) + def summary(self) -> Dict[str, Any]: """ :return: A dict like object with summary of current analysis diff --git a/src/sparsezoo/analyze_cli.py b/src/sparsezoo/analyze_cli.py index e21aba83..2f635831 100644 --- a/src/sparsezoo/analyze_cli.py +++ b/src/sparsezoo/analyze_cli.py @@ -61,12 +61,10 @@ """ import copy import logging -from pathlib import Path from typing import Optional import click import pandas as pd -from sparsezoo import Model from sparsezoo.analyze import ModelAnalysis @@ -101,10 +99,8 @@ def main(model_path: str, save: Optional[str], **kwargs): f"--{unimplemented_feat} has not been implemented yet" ) - model_file_path = _get_model_file_path(model_path=model_path) - LOGGER.info("Starting Analysis ...") - analysis = ModelAnalysis.from_onnx(model_file_path) + analysis = ModelAnalysis.create(model_path) LOGGER.info("Analysis complete, collating results...") summary = analysis.summary() @@ -117,18 +113,6 @@ def main(model_path: str, save: Optional[str], **kwargs): analysis.yaml(file_path=save) -def _get_model_file_path(model_path: str): - if model_path.startswith("zoo:"): - LOGGER.info(f"Downloading files from SparseZoo: '{model_path}'") - model = Model(model_path) - model_path = Path(model.deployment.get_file("model.onnx").path) - elif Path(model_path).is_file(): - model_path = model_path - else: - model_path = Path(model_path) / "model.onnx" - return model_path - - def _display_summary_as_table(summary): summary_copy = copy.copy(summary) print(f"MODEL: {summary_copy.pop('MODEL')}", end="\n\n") diff --git a/tests/sparsezoo/analyze/test_model_analysis_creation.py b/tests/sparsezoo/analyze/test_model_analysis_creation.py new file mode 100644 index 00000000..796d8e51 --- /dev/null +++ b/tests/sparsezoo/analyze/test_model_analysis_creation.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import onnx +import pytest + +from sparsezoo import Model +from sparsezoo.analyze import ModelAnalysis + + +def onnx_stub(): + return ( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/" + "imagenet/pruned95_quant-none" + ) + + +def onnx_deployment_dir(): + return Model(onnx_stub()).deployment.path + + +def onnx_local_path(): + return str(Path(onnx_deployment_dir()) / "model.onnx") + + +def onnx_model(): + return onnx.load(onnx_local_path()) + + +@pytest.mark.parametrize( + "file_path, should_error", + [ + (onnx_stub(), False), + (onnx_deployment_dir(), False), + (onnx_local_path(), False), + (onnx_model(), False), + (1, True), + ], +) +def test_create(file_path, should_error): + if should_error: + with pytest.raises(ValueError, match="Invalid"): + ModelAnalysis.create(file_path) + else: + analysis = ModelAnalysis.create(file_path) + assert isinstance(analysis, ModelAnalysis) + + +@pytest.mark.parametrize( + "model_path", + [ + onnx_local_path(), + ], +) +def test_yaml_serialization(model_path, tmp_path): + analysis = ModelAnalysis.create(file_path=model_path) + + yaml_file = str(tmp_path / "quantized-resnet.yaml") + analysis.yaml(file_path=yaml_file) + + analysis_from_yaml = ModelAnalysis.create(file_path=yaml_file) + + assert analysis.yaml() == analysis_from_yaml.yaml()