Skip to content

Commit

Permalink
Added credentials option for iter_bucket. (#372)
Browse files Browse the repository at this point in the history
* Added credentials option for iter_bucket.

This closes issue #259

* Applied PR Feedback

* Formatting changes

* explicitly construct session with session_kwargs

* fix bug, update unit tests

* remove test with tricky mocks

* isolate credentials test

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
derpferd and mpenkov authored Mar 11, 2020
1 parent f729054 commit 68a39d9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
40 changes: 32 additions & 8 deletions smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,14 @@ def _accept_all(key):
return True


def iter_bucket(bucket_name, prefix='', accept_key=None,
key_limit=None, workers=16, retries=3):
def iter_bucket(
bucket_name,
prefix='',
accept_key=None,
key_limit=None,
workers=16,
retries=3,
**session_kwargs):
"""
Iterate and download all S3 objects under `s3://bucket_name/prefix`.
Expand All @@ -676,6 +682,11 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
The number of subprocesses to use.
retries: int, optional
The number of time to retry a failed download.
session_kwargs: dict, optional
Keyword arguments to pass when creating a new session.
For a list of available names and values, see:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session
Yields
------
Expand Down Expand Up @@ -716,8 +727,16 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
pass

total_size, key_no = 0, -1
key_iterator = _list_bucket(bucket_name, prefix=prefix, accept_key=accept_key)
download_key = functools.partial(_download_key, bucket_name=bucket_name, retries=retries)
key_iterator = _list_bucket(
bucket_name,
prefix=prefix,
accept_key=accept_key,
**session_kwargs)
download_key = functools.partial(
_download_key,
bucket_name=bucket_name,
retries=retries,
**session_kwargs)

with _create_process_pool(processes=workers) as pool:
result_iterator = pool.imap_unordered(download_key, key_iterator)
Expand All @@ -736,8 +755,13 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))


def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
client = boto3.client('s3')
def _list_bucket(
bucket_name,
prefix='',
accept_key=lambda k: True,
**session_kwargs):
session = boto3.session.Session(**session_kwargs)
client = session.client('s3')
ctoken = None

while True:
Expand All @@ -762,14 +786,14 @@ def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
break


def _download_key(key_name, bucket_name=None, retries=3):
def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs):
if bucket_name is None:
raise ValueError('bucket_name may not be None')

#
# https://geekpete.com/blog/multithreading-boto3/
#
session = boto3.session.Session()
session = boto3.session.Session(**session_kwargs)
s3 = session.resource('s3')
bucket = s3.Bucket(bucket_name)

Expand Down
21 changes: 21 additions & 0 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,27 @@ def test(self):
self.assertEqual(sorted(keys), sorted(expected))


#
# This has to be a separate test because we cannot run it against real S3
# (we don't want to expose our real S3 credentials).
#
@moto.mock_s3
class IterBucketCredentialsTest(unittest.TestCase):
    """Verify that iter_bucket accepts and forwards explicit AWS credentials.

    Runs against the moto mock only: we deliberately avoid real S3 here,
    because the test supplies made-up credentials.
    """

    def test(self):
        expected_count = 10
        populate_bucket(num_keys=expected_count)

        # NOTE(review): workers=None presumably sidesteps the worker pool /
        # lets it pick a default under the moto mock — confirm against
        # _create_process_pool's handling of processes=None.
        downloaded = smart_open.s3.iter_bucket(
            BUCKET_NAME,
            workers=None,
            aws_access_key_id='access_id',
            aws_secret_access_key='access_secret'
        )

        self.assertEqual(len(list(downloaded)), expected_count)


@maybe_mock_s3
class DownloadKeyTest(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 68a39d9

Please sign in to comment.