Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use boto3 session in S3Upload and S3Download tasks #3981

Merged
merged 10 commits into from
Feb 1, 2021
5 changes: 5 additions & 0 deletions changes/pr3981.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
task:
- "Use boto3 session in `S3Upload` and `S3Download` tasks, to ensure thread-safe execution - [#3981](https://github.com/PrefectHQ/prefect/pull/3981)"

contributor:
- "[Gregory Roche](https://github.com/gregoryroche)"
8 changes: 6 additions & 2 deletions src/prefect/tasks/aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def run(
if bucket is None:
raise ValueError("A bucket name must be provided.")

s3_client = get_boto_client("s3", credentials=credentials, **self.boto_kwargs)
s3_client = get_boto_client(
"s3", credentials=credentials, use_session=True, **self.boto_kwargs
)

stream = io.BytesIO()

Expand Down Expand Up @@ -144,7 +146,9 @@ def run(
if bucket is None:
raise ValueError("A bucket name must be provided.")

s3_client = get_boto_client("s3", credentials=credentials, **self.boto_kwargs)
s3_client = get_boto_client(
"s3", credentials=credentials, use_session=True, **self.boto_kwargs
)

# compress data if compression is specified
if compression:
Expand Down
60 changes: 33 additions & 27 deletions tests/tasks/aws/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
from prefect.utilities.configuration import set_temporary_config


@pytest.fixture
def mocked_boto_client(monkeypatch):
boto3 = MagicMock()
client = boto3.session.Session().client()
boto3.client = MagicMock(return_value=client)
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)
return client


class TestS3Download:
def test_initialization(self):
task = S3Download()
Expand All @@ -22,31 +31,31 @@ def test_raises_if_bucket_not_eventually_provided(self):
with pytest.raises(ValueError, match="bucket"):
task.run(key="")

def test_gzip_compression(self, monkeypatch):
def test_gzip_compression(self, mocked_boto_client):
task = S3Download("bucket")
byte_string = b"col1,col2,col3\nfake,data,1\nfalse,data,2\n"
gzip_data = gzip.compress(byte_string)

def modify_stream(Bucket=None, Key=None, Fileobj=None):
Fileobj.write(gzip_data)

client = MagicMock()
boto3 = MagicMock(client=MagicMock(return_value=client))
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)
client.download_fileobj.side_effect = modify_stream

mocked_boto_client.download_fileobj.side_effect = modify_stream
returned_data = task.run("key", compression="gzip")
assert returned_data == str(byte_string, "utf-8")

def test_raises_on_invalid_compression_method(self, monkeypatch):
def test_raises_on_invalid_compression_method(self, mocked_boto_client):
task = S3Download("test")
client = MagicMock()
boto3 = MagicMock(client=MagicMock(return_value=client))
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)

with pytest.raises(ValueError, match="gz_fake"):
task.run("key", compression="gz_fake")

def test_boto3_client_is_created_with_session(self, mocked_boto_client):
""" Tests the fix for #3925 """
task = S3Download("test")
result = task.run("key")
assert (
"session.Session()" in mocked_boto_client.return_value._extract_mock_name()
)


class TestS3Upload:
def test_initialization(self):
Expand All @@ -62,42 +71,39 @@ def test_raises_if_bucket_not_eventually_provided(self):
with pytest.raises(ValueError, match="bucket"):
task.run(data="")

def test_generated_key_is_str(self, monkeypatch):
def test_generated_key_is_str(self, mocked_boto_client):
task = S3Upload(bucket="test")
client = MagicMock()
boto3 = MagicMock(client=MagicMock(return_value=client))
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)
with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(
secrets=dict(
AWS_CREDENTIALS={"ACCESS_KEY": "42", "SECRET_ACCESS_KEY": "99"}
)
):
task.run(data="")
assert type(client.upload_fileobj.call_args[1]["Key"]) == str
assert type(mocked_boto_client.upload_fileobj.call_args[1]["Key"]) == str

def test_gzip_compression(self, monkeypatch):
def test_gzip_compression(self, mocked_boto_client):
task = S3Upload("bucket")
byte_string = b"col1,col2,col3\nfake,data,1\nfalse,info,2\n"

client = MagicMock()
boto3 = MagicMock(client=MagicMock(return_value=client))
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)

task.run(byte_string, key="key", compression="gzip")
args, kwargs = client.upload_fileobj.call_args_list[0]
args, kwargs = mocked_boto_client.upload_fileobj.call_args_list[0]
gzip_data_stream = args[0]
assert gzip.decompress(gzip_data_stream.read()) == byte_string

def test_raises_on_invalid_compression_method(self, monkeypatch):
def test_raises_on_invalid_compression_method(self, mocked_boto_client):
task = S3Upload("test")
client = MagicMock()
boto3 = MagicMock(client=MagicMock(return_value=client))
monkeypatch.setattr("prefect.utilities.aws.boto3", boto3)

with pytest.raises(ValueError, match="gz_fake"):
task.run(b"data", compression="gz_fake")

def test_boto3_client_is_created_with_session(self, mocked_boto_client):
""" Tests the fix for #3925 """
task = S3Upload("test")
result = task.run("key")
assert (
"session.Session()" in mocked_boto_client.return_value._extract_mock_name()
)


class TestS3List:
def test_initialization(self):
Expand Down