Skip to content

Commit

Permalink
Storage initializer fix so that it downloads only the specific file when …
Browse files Browse the repository at this point in the history
…the provided uri is not a folder (kubeflow#3088)

* fix issue kubeflow#2473

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* add test

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* flake8

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* update target_key logic and add test

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* update

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* Implemented download of the exact S3 object.

Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* Fixed s3 storage test.

Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

* Refactor code to make it more readable and reduce cognitive complexity

Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>

---------

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
Co-authored-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
  • Loading branch information
andyi2it and tjandy98 authored Sep 8, 2023
1 parent f7de5e6 commit 29e3515
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
24 changes: 19 additions & 5 deletions python/kserve/kserve/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def _download_s3(uri, temp_dir: str):
bucket_path = parsed.path.lstrip('/')

count = 0
exact_obj_found = False
bucket = s3.Bucket(bucket_name)
for obj in bucket.objects.filter(Prefix=bucket_path):
# Skip where boto3 lists the directory as an object
Expand All @@ -191,17 +192,30 @@ def _download_s3(uri, temp_dir: str):
# If 'uri' is set to "s3://test-bucket/a/b/c/model.bin", then
# the downloader will add to temp dir: model.bin
# (without any subpaths).
target_key = (
obj.key.rsplit("/", 1)[-1]
if bucket_path == obj.key
else obj.key.replace(bucket_path, "", 1).lstrip("/")
)
# If the bucket path is s3://test/models
# Objects: churn, churn-pickle, churn-pickle-logs
bucket_path_last_part = bucket_path.split("/")[-1]
object_last_path = obj.key.split("/")[-1]
bucket_path_parent_part = bucket_path.rsplit("/", 1)[0]

if bucket_path == obj.key:
target_key = obj.key.rsplit("/", 1)[-1]
exact_obj_found = True
elif object_last_path.startswith(bucket_path_last_part):
target_key = obj.key.replace(bucket_path_parent_part, "", 1).lstrip("/")
else:
target_key = obj.key.replace(bucket_path, "").lstrip("/")

target = f"{temp_dir}/{target_key}"
if not os.path.exists(os.path.dirname(target)):
os.makedirs(os.path.dirname(target), exist_ok=True)
bucket.download_file(obj.key, target)
logging.info('Downloaded object %s to %s' % (obj.key, target))
count = count + 1

# If the exact object is found, then it is sufficient to download that and break the loop
if exact_obj_found:
break
if count == 0:
raise RuntimeError(
"Failed to fetch model. No model found in %s." % bucket_path)
Expand Down
40 changes: 39 additions & 1 deletion python/kserve/kserve/storage/test/test_s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from botocore.client import Config
from botocore import UNSIGNED

from kserve.storage import Storage

STORAGE_MODULE = 'kserve.storage.storage'
Expand Down Expand Up @@ -140,6 +139,45 @@ def test_full_name_key_root_bucket_dir(mock_storage):
"AWS_SESSION_TOKEN": "testing"}


@mock.patch(STORAGE_MODULE + '.boto3')
def test_multikey(mock_storage):
    """Objects nested under a folder-like prefix are downloaded with
    their paths taken relative to that prefix."""
    # given: one object nested two levels under the 'test/a' prefix
    bucket_name = 'foo'
    relative_paths = ['b/model.bin']
    keys = [f'test/a/{rel}' for rel in relative_paths]

    # when: downloading the folder-like uri
    boto3_bucket = create_mock_boto3_bucket(mock_storage, keys)
    Storage._download_s3(f's3://{bucket_name}/test/a', 'dest_path')

    # then: the object lands in dest_path at its prefix-relative path
    calls = get_call_args(boto3_bucket.download_file.call_args_list)
    assert calls == expected_call_args_list('test/a', 'dest_path', relative_paths)
    boto3_bucket.objects.filter.assert_called_with(Prefix='test/a')


@mock.patch(STORAGE_MODULE + '.boto3')
def test_files_with_no_extension(mock_storage):
    """When the uri names an object exactly, only that object is downloaded,
    even though sibling keys share it as a name prefix."""
    # given: extensionless siblings that all start with 'churn-pickle'
    bucket_name = 'foo'
    relative_paths = ['churn-pickle', 'churn-pickle-logs', 'churn-pickle-report']
    keys = [f'test/{rel}' for rel in relative_paths]

    # when: the uri matches the first key exactly
    boto3_bucket = create_mock_boto3_bucket(mock_storage, keys)
    Storage._download_s3(f's3://{bucket_name}/test/churn-pickle', 'dest_path')

    # then: download only the exact file if found; otherwise, download
    # all files with the given prefix — here only the first call matters
    calls = get_call_args(boto3_bucket.download_file.call_args_list)
    assert calls[0] == expected_call_args_list('test', 'dest_path', relative_paths)[0]
    boto3_bucket.objects.filter.assert_called_with(Prefix='test/churn-pickle')


def test_get_S3_config():
DEFAULT_CONFIG = Config()
ANON_CONFIG = Config(signature_version=UNSIGNED)
Expand Down

0 comments on commit 29e3515

Please sign in to comment.