Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

default Model.onnx_model to target onnx.model.tar.gz #355

Merged
merged 8 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sparsezoo/analyze/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def __sub__(self, other):
my_value = getattr(self, field)
other_value = getattr(other, field)

assert type(my_value) == type(other_value)
assert type(my_value) is type(other_value)
if field == "section_name":
new_fields[field] = my_value
elif isinstance(my_value, str):
Expand Down
10 changes: 9 additions & 1 deletion src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Directory,
File,
NumpyDirectory,
OnnxGz,
SelectDirectory,
is_directory,
)
Expand Down Expand Up @@ -156,7 +157,14 @@ def __init__(self, source: str, download_path: Optional[str] = None):
stub_params=self.stub_params,
)

self.onnx_model: File = self._file_from_files(files, display_name="model.onnx")
self._onnx_gz: Directory = self._directory_from_files(
files, directory_class=OnnxGz, display_name="model.onnx.tar.gz"
)
self.onnx_model: File = (
self._file_from_files(files, display_name="model.onnx")
if self._onnx_gz is None
else self._onnx_gz # if model.onnx.tar.gz is present, defer to that archive
)

self.analysis: File = self._file_from_files(files, display_name="analysis.yaml")
self.benchmarks: File = self._file_from_files(
Expand Down
8 changes: 7 additions & 1 deletion src/sparsezoo/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ThroughputResults,
ValidationResult,
)
from sparsezoo.objects import Directory, File, NumpyDirectory
from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz
from sparsezoo.utils import BASE_API_URL, convert_to_bool, save_numpy


Expand Down Expand Up @@ -575,6 +575,12 @@ def _copy_file_contents(
for _file in file:
copy_path = os.path.join(output_dir, os.path.basename(_file.path))
_copy_and_overwrite(_file.path, copy_path, shutil.copyfile)
elif isinstance(file, OnnxGz):
# copy all contents of the unzipped model.onnx.tar.gz archive to the top level of output_dir
onnx_gz_path = (
os.path.dirname(file.path) if os.path.isfile(file.path) else file.path
)
shutil.copytree(onnx_gz_path, output_dir, dirs_exist_ok=True)
elif isinstance(file, Directory):
copy_path = os.path.join(output_dir, os.path.basename(file.path))
_copy_and_overwrite(file.path, copy_path, shutil.copytree)
Expand Down
27 changes: 26 additions & 1 deletion src/sparsezoo/objects/directories.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import logging
import os.path
from collections import OrderedDict
from typing import Dict, List, Optional, Union

Expand All @@ -29,7 +30,11 @@
from sparsezoo.utils import DataLoader, Dataset, load_numpy_list


__all__ = ["NumpyDirectory", "SelectDirectory"]
__all__ = [
"NumpyDirectory",
"SelectDirectory",
"OnnxGz",
]

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -269,3 +274,23 @@ def available(self):
@available.setter
def available(self, value):
self._available = value


class OnnxGz(Directory):
    """
    Directory specialisation for the `model.onnx.tar.gz` archive.

    The intent is that information about every file packed in the tarball
    stays available on the object, while the `path` property resolves to the
    `model.onnx` file itself. That is the path expected when loading an ONNX
    model, with or without external data.
    """

    @property
    def path(self):
        """
        Resolve the location of the ONNX model.

        Note that this getter has a side effect: once the archive has been
        extracted into a directory that contains `model.onnx`, `self._path`
        is rewritten to point at that file.

        :return: path to `model.onnx` when it is found inside the extracted
            directory, otherwise the current value of `self._path`
        """
        # Read the parent property first; it is relied on to download the
        # initial file if that has not happened yet. The value is not needed.
        _ = super().path

        if self.is_archive:
            self.unzip()

        location = self._path
        if not os.path.isdir(location):
            # not a directory (e.g. already pointing at model.onnx)
            return location

        if "model.onnx" in os.listdir(location):
            # unzipped into a directory: refer directly to model.onnx
            self._path = os.path.join(location, "model.onnx")
        return self._path
31 changes: 22 additions & 9 deletions src/sparsezoo/objects/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Directory(File):
:param path: path of the Directory
:param url: url of the Directory
:param parent_directory: path of the parent Directory
:param force: boolean flag; True to force unzipping of archive files.
Default is False.
"""

def __init__(
Expand All @@ -50,6 +52,7 @@ def __init__(
path: Optional[str] = None,
url: Optional[str] = None,
parent_directory: Optional[str] = None,
force: bool = False,
):

self.files = (
Expand All @@ -63,7 +66,7 @@ def __init__(
)

if self._unpack():
self.unzip()
self.unzip(force=force)

@classmethod
def from_file(cls, file: File) -> "Directory":
Expand Down Expand Up @@ -207,6 +210,8 @@ def get_file(self, file_name: str) -> Optional[File]:
:return: File if found, otherwise None
"""
for file in self.files:
if file is None:
continue
if file.name == file_name:
return file
if isinstance(file, Directory):
Expand Down Expand Up @@ -254,18 +259,23 @@ def gzip(self, archive_directory: Optional[str] = None):
self._path = tar_file_path
self.is_archive = True

def unzip(self, extract_directory: Optional[str] = None):
def unzip(self, extract_directory: Optional[str] = None, force: bool = False):
"""
Extracts a tar archive Directory.
The extracted files would be saved in the parent directory of
`self`, unless `extract_directory` argument is specified

:param extract_directory: the local path to create
folder Directory at (default = None)
:param force: if True, will always unzip, even if the target directory
already exists. Default False
"""
if self._path is None:
# use path property to download so path exists
self._path = self.path
files = []
if extract_directory is None:
extract_directory = os.path.dirname(self.path)
extract_directory = os.path.dirname(self._path)

if not self.is_archive:
raise ValueError(
Expand All @@ -274,14 +284,17 @@ def unzip(self, extract_directory: Optional[str] = None):
)

name = ".".join(self.name.split(".")[:-2])
tar = tarfile.open(self.path, "r")
path = os.path.join(extract_directory, name)

for member in tar.getmembers():
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(File(name=member.name, path=os.path.join(path, member.name)))
tar.close()
if not os.path.exists(path) or force: # do not re-unzip if not forced
tar = tarfile.open(self._path, "r")
for member in tar.getmembers():
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(
File(name=member.name, path=os.path.join(path, member.name))
)
tar.close()

self.name = name
self.files = files
Expand Down
4 changes: 3 additions & 1 deletion tests/sparsezoo/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"logs",
"onnx",
"model.onnx",
"model.onnx.tar.gz",
"recipe",
"sample_inputs.tar.gz",
"sample_originals.tar.gz",
Expand Down Expand Up @@ -198,6 +199,7 @@ def test_folder_structure(self, setup):
"sample_outputs_deepsparse",
]:
expected_files.update({file_name, file_name + ".tar.gz"})

assert not set(os.listdir(temp_dir.name)).difference(expected_files)

def test_validate(self, setup):
Expand All @@ -223,7 +225,7 @@ def _add_mock_files(directory_path: str, clone_sample_outputs: bool):
os.makedirs(onnx_folder_dir)
for opset in range(1, 3):
shutil.copyfile(
os.path.join(directory_path, "model.onnx"),
os.path.join(directory_path, "deployment", "model.onnx"),
os.path.join(onnx_folder_dir, f"model.{opset}.onnx"),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/objects/test_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_zipping_on_creation(self, setup):
) = setup
directory = Directory(name=name, files=files, path=path)
directory.gzip()
new_directory = Directory(name=directory.name, path=directory.path)
new_directory = Directory(name=directory.name, path=directory.path, force=True)
assert os.path.isdir(new_directory.path)
assert new_directory.path == directory.path.replace(".tar.gz", "")
assert new_directory.files
Expand Down