diff --git a/python/kserve/kserve/storage.py b/python/kserve/kserve/storage.py index f011f8aae80..848f8d88c7f 100644 --- a/python/kserve/kserve/storage.py +++ b/python/kserve/kserve/storage.py @@ -109,6 +109,7 @@ def _download_s3(uri, temp_dir: str): bucket_name = parsed.netloc bucket_path = parsed.path.lstrip('/') + count = 0 bucket = s3.Bucket(bucket_name) for obj in bucket.objects.filter(Prefix=bucket_path): # Skip where boto3 lists the directory as an object @@ -125,6 +126,13 @@ def _download_s3(uri, temp_dir: str): if not os.path.exists(os.path.dirname(target)): os.makedirs(os.path.dirname(target), exist_ok=True) bucket.download_file(obj.key, target) + count = count + 1 + + # Unpack compressed file, supports .tgz, tar.gz and zip file formats. + if count == 1: + mimetype, _ = mimetypes.guess_type(target) + if mimetype in ["application/x-tar", "application/zip"]: + Storage._unpack_archive_file(target, mimetype, temp_dir) @staticmethod def _download_gcs(uri, temp_dir: str): @@ -159,6 +167,12 @@ def _download_gcs(uri, temp_dir: str): raise RuntimeError("Failed to fetch model. \ The path or model %s does not exist." % uri) + # Unpack compressed file, supports .tgz, tar.gz and zip file formats. + if count == 1: + mimetype, _ = mimetypes.guess_type(blob.name) + if mimetype in ["application/x-tar", "application/zip"]: + Storage._unpack_archive_file(dest_path, mimetype, temp_dir) + @staticmethod def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals match = re.search(_BLOB_RE, uri) @@ -203,6 +217,12 @@ def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals raise RuntimeError("Failed to fetch model. \ The path or model %s does not exist." % (uri)) + # Unpack compressed file, supports .tgz, tar.gz and zip file formats. + if count == 1: + mimetype, _ = mimetypes.guess_type(dest_path) + if mimetype in ["application/x-tar", "application/zip"]: + Storage._unpack_archive_file(dest_path, mimetype, out_dir) + @staticmethod def _get_azure_storage_token(): tenant_id = os.getenv("AZ_TENANT_ID", "") @@ -236,11 +256,20 @@ def _download_local(uri, out_dir=None): if os.path.isdir(local_path): local_path = os.path.join(local_path, "*") + count = 0 for src in glob.glob(local_path): _, tail = os.path.split(src) dest_path = os.path.join(out_dir, tail) logging.info("Linking: %s to %s", src, dest_path) os.symlink(src, dest_path) + count = count + 1 + + # Unpack compressed file, supports .tgz, tar.gz and zip file formats. + if count == 1: + mimetype, _ = mimetypes.guess_type(dest_path) + if mimetype in ["application/x-tar", "application/zip"]: + Storage._unpack_archive_file(dest_path, mimetype, out_dir) + return out_dir @staticmethod @@ -287,12 +316,24 @@ def _download_from_uri(uri, out_dir=None): shutil.copyfileobj(stream, out) if mimetype in ["application/x-tar", "application/zip"]: + Storage._unpack_archive_file(local_path, mimetype, out_dir) + + return out_dir + + @staticmethod + def _unpack_archive_file(file_path, mimetype, target_dir=None): + if not target_dir: + target_dir = os.path.dirname(file_path) + + try: + logging.info("Unpacking: %s", file_path) if mimetype == "application/x-tar": - archive = tarfile.open(local_path, 'r', encoding='utf-8') + archive = tarfile.open(file_path, 'r', encoding='utf-8') else: - archive = zipfile.ZipFile(local_path, 'r') - archive.extractall(out_dir) + archive = zipfile.ZipFile(file_path, 'r') + archive.extractall(target_dir) archive.close() - os.remove(local_path) - - return out_dir + except (tarfile.TarError, zipfile.BadZipfile): + raise RuntimeError("Failed to unpack archieve file. \ +The file format is not valid.") + os.remove(file_path) diff --git a/python/kserve/test/test_storage.py b/python/kserve/test/test_storage.py index 3716dfc5710..af3a9a0addb 100644 --- a/python/kserve/test/test_storage.py +++ b/python/kserve/test/test_storage.py @@ -16,6 +16,8 @@ import tempfile import binascii import unittest.mock as mock +import mimetypes +from pathlib import Path import botocore import kserve @@ -174,3 +176,23 @@ def test_no_permission_buckets(mock_connection, mock_boto3): with pytest.raises(botocore.exceptions.ClientError): kserve.Storage.download(bad_s3_path) + + +def test_unpack_tar_file(): + out_dir = '.' + tar_file = os.path.join(out_dir, "model.tgz") + Path(tar_file).write_bytes(FILE_TAR_GZ_RAW) + mimetype, _ = mimetypes.guess_type(tar_file) + kserve.Storage._unpack_archive_file(tar_file, mimetype, out_dir) + assert os.path.exists(os.path.join(out_dir, 'model.pth')) + os.remove(os.path.join(out_dir, 'model.pth')) + + +def test_unpack_zip_file(): + out_dir = '.' + tar_file = os.path.join(out_dir, "model.zip") + Path(tar_file).write_bytes(FILE_ZIP_RAW) + mimetype, _ = mimetypes.guess_type(tar_file) + kserve.Storage._unpack_archive_file(tar_file, mimetype, out_dir) + assert os.path.exists(os.path.join(out_dir, 'model.pth')) + os.remove(os.path.join(out_dir, 'model.pth'))