Skip to content

Commit

Permalink
adding tgz/zip support for all storages (kubeflow#1836)
Browse files Browse the repository at this point in the history
* adding tgz/zip support for all storages

* Fix lint

Co-authored-by: Dan Sun <dsun20@bloomberg.net>
  • Loading branch information
Suresh-Nakkeran and yuzisun authored Oct 5, 2021
1 parent 0905b47 commit 3dfff84
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
53 changes: 47 additions & 6 deletions python/kserve/kserve/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _download_s3(uri, temp_dir: str):
bucket_name = parsed.netloc
bucket_path = parsed.path.lstrip('/')

count = 0
bucket = s3.Bucket(bucket_name)
for obj in bucket.objects.filter(Prefix=bucket_path):
# Skip where boto3 lists the directory as an object
Expand All @@ -125,6 +126,13 @@ def _download_s3(uri, temp_dir: str):
if not os.path.exists(os.path.dirname(target)):
os.makedirs(os.path.dirname(target), exist_ok=True)
bucket.download_file(obj.key, target)
count = count + 1

# Unpack compressed file, supports .tgz, tar.gz and zip file formats.
if count == 1:
mimetype, _ = mimetypes.guess_type(target)
if mimetype in ["application/x-tar", "application/zip"]:
Storage._unpack_archive_file(target, mimetype, temp_dir)

@staticmethod
def _download_gcs(uri, temp_dir: str):
Expand Down Expand Up @@ -159,6 +167,12 @@ def _download_gcs(uri, temp_dir: str):
raise RuntimeError("Failed to fetch model. \
The path or model %s does not exist." % uri)

# Unpack compressed file, supports .tgz, tar.gz and zip file formats.
if count == 1:
mimetype, _ = mimetypes.guess_type(blob.name)
if mimetype in ["application/x-tar", "application/zip"]:
Storage._unpack_archive_file(dest_path, mimetype, temp_dir)

@staticmethod
def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals
match = re.search(_BLOB_RE, uri)
Expand Down Expand Up @@ -203,6 +217,12 @@ def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals
raise RuntimeError("Failed to fetch model. \
The path or model %s does not exist." % (uri))

# Unpack compressed file, supports .tgz, tar.gz and zip file formats.
if count == 1:
mimetype, _ = mimetypes.guess_type(dest_path)
if mimetype in ["application/x-tar", "application/zip"]:
Storage._unpack_archive_file(dest_path, mimetype, out_dir)

@staticmethod
def _get_azure_storage_token():
tenant_id = os.getenv("AZ_TENANT_ID", "")
Expand Down Expand Up @@ -236,11 +256,20 @@ def _download_local(uri, out_dir=None):
if os.path.isdir(local_path):
local_path = os.path.join(local_path, "*")

count = 0
for src in glob.glob(local_path):
_, tail = os.path.split(src)
dest_path = os.path.join(out_dir, tail)
logging.info("Linking: %s to %s", src, dest_path)
os.symlink(src, dest_path)
count = count + 1

# Unpack compressed file, supports .tgz, tar.gz and zip file formats.
if count == 1:
mimetype, _ = mimetypes.guess_type(dest_path)
if mimetype in ["application/x-tar", "application/zip"]:
Storage._unpack_archive_file(dest_path, mimetype, out_dir)

return out_dir

@staticmethod
Expand Down Expand Up @@ -287,12 +316,24 @@ def _download_from_uri(uri, out_dir=None):
shutil.copyfileobj(stream, out)

if mimetype in ["application/x-tar", "application/zip"]:
Storage._unpack_archive_file(local_path, mimetype, out_dir)

return out_dir

@staticmethod
def _unpack_archive_file(file_path, mimetype, target_dir=None):
if not target_dir:
target_dir = os.path.dirname(file_path)

try:
logging.info("Unpacking: %s", file_path)
if mimetype == "application/x-tar":
archive = tarfile.open(local_path, 'r', encoding='utf-8')
archive = tarfile.open(file_path, 'r', encoding='utf-8')
else:
archive = zipfile.ZipFile(local_path, 'r')
archive.extractall(out_dir)
archive = zipfile.ZipFile(file_path, 'r')
archive.extractall(target_dir)
archive.close()
os.remove(local_path)

return out_dir
except (tarfile.TarError, zipfile.BadZipfile):
raise RuntimeError("Failed to unpack archieve file. \
The file format is not valid.")
os.remove(file_path)
22 changes: 22 additions & 0 deletions python/kserve/test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tempfile
import binascii
import unittest.mock as mock
import mimetypes
from pathlib import Path

import botocore
import kserve
Expand Down Expand Up @@ -174,3 +176,23 @@ def test_no_permission_buckets(mock_connection, mock_boto3):

with pytest.raises(botocore.exceptions.ClientError):
kserve.Storage.download(bad_s3_path)


def test_unpack_tar_file():
out_dir = '.'
tar_file = os.path.join(out_dir, "model.tgz")
Path(tar_file).write_bytes(FILE_TAR_GZ_RAW)
mimetype, _ = mimetypes.guess_type(tar_file)
kserve.Storage._unpack_archive_file(tar_file, mimetype, out_dir)
assert os.path.exists(os.path.join(out_dir, 'model.pth'))
os.remove(os.path.join(out_dir, 'model.pth'))


def test_unpack_zip_file():
out_dir = '.'
tar_file = os.path.join(out_dir, "model.zip")
Path(tar_file).write_bytes(FILE_ZIP_RAW)
mimetype, _ = mimetypes.guess_type(tar_file)
kserve.Storage._unpack_archive_file(tar_file, mimetype, out_dir)
assert os.path.exists(os.path.join(out_dir, 'model.pth'))
os.remove(os.path.join(out_dir, 'model.pth'))

0 comments on commit 3dfff84

Please sign in to comment.