Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix s3 session creation in deployment steps push_to_s3 and pull_from_s3 #322

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Changed `push_to_s3` deployment step function to write remote paths using `as_posix()` so that deploying from Windows is supported [#314](https://github.com/PrefectHQ/prefect-aws/pull/314)
- Changed `push_to_s3` and `pull_from_s3` deployment steps to properly create a boto3 session client if the passed credentials are a referenced AwsCredentials block [#322](https://github.com/PrefectHQ/prefect-aws/pull/322)

### Fixed

Expand Down
68 changes: 51 additions & 17 deletions prefect_aws/deployments/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,7 @@ def push_to_s3(
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})
client = boto3.client(
"s3", **credentials, **client_parameters, config=Config(**advanced_config)
)
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -117,7 +110,7 @@ def push_to_s3(
continue
elif not local_file_path.is_dir():
remote_file_path = Path(folder) / local_file_path.relative_to(local_path)
client.upload_file(
s3.upload_file(
str(local_file_path), bucket, str(remote_file_path.as_posix())
)

Expand Down Expand Up @@ -174,14 +167,7 @@ def pull_from_s3(
credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}"
```
"""
if credentials is None:
credentials = {}
if client_parameters is None:
client_parameters = {}
advanced_config = client_parameters.pop("config", {})

session = boto3.Session(**credentials)
s3 = session.client("s3", **client_parameters, config=Config(**advanced_config))
s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters)

local_path = Path.cwd()

Expand All @@ -206,3 +192,51 @@ def pull_from_s3(
"folder": folder,
"directory": str(local_path),
}


def get_s3_client(
    credentials: Optional[Dict] = None,
    client_parameters: Optional[Dict] = None,
) -> dict:
    """Build an S3 client from plain credential and client-parameter dicts.

    A `boto3.Session` is created from the credential values and an "s3"
    client is then created from that session.

    Args:
        credentials: Optional mapping. The keys read are `aws_access_key_id`,
            `aws_secret_access_key`, `aws_session_token`, `profile_name`,
            `region_name` and `aws_client_parameters`. Missing keys are
            treated as None.
        client_parameters: Optional mapping used as a fallback for
            `profile_name` and `region_name`, and as the source of
            `api_version`, `endpoint_url`, `use_ssl`, `verify` and `config`
            when `credentials` carries no `aws_client_parameters`.

    Returns:
        Whatever `session.client("s3", ...)` returns. The `dict` annotation
        is kept for backward compatibility; the value is a client object,
        not a mapping.
    """
    if credentials is None:
        credentials = {}
    if client_parameters is None:
        client_parameters = {}

    # Secrets are passed through to the session untouched.
    aws_access_key_id = credentials.get("aws_access_key_id")
    aws_secret_access_key = credentials.get("aws_secret_access_key")
    aws_session_token = credentials.get("aws_session_token")

    # Session settings: credentials win, client_parameters is the fallback.
    # The fallback is applied on None rather than on a missing key, so that a
    # credentials dict carrying an explicit `"profile_name": None` (as may
    # happen when a credentials object is serialized to a dict) still picks
    # up the value supplied through client_parameters.
    profile_name = credentials.get("profile_name")
    if profile_name is None:
        profile_name = client_parameters.get("profile_name")
    region_name = credentials.get("region_name")
    if region_name is None:
        region_name = client_parameters.get("region_name")

    # Client settings: a nested `aws_client_parameters` mapping inside
    # credentials takes precedence; otherwise client_parameters is used.
    aws_client_parameters = credentials.get("aws_client_parameters")
    if aws_client_parameters is None:
        aws_client_parameters = client_parameters
    api_version = aws_client_parameters.get("api_version")
    endpoint_url = aws_client_parameters.get("endpoint_url")
    # NOTE(review): when `use_ssl` is absent, None is passed explicitly
    # instead of leaving the argument out. boto3 documents the default as
    # True - confirm that None is not interpreted as "disable SSL".
    use_ssl = aws_client_parameters.get("use_ssl")
    verify = aws_client_parameters.get("verify")
    # `config` may be present but None; Config(**None) would raise TypeError.
    config_params = aws_client_parameters.get("config") or {}
    config = Config(**config_params)

    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
        profile_name=profile_name,
        region_name=region_name,
    )
    return session.client(
        "s3",
        api_version=api_version,
        endpoint_url=endpoint_url,
        use_ssl=use_ssl,
        verify=verify,
        config=config,
    )
96 changes: 95 additions & 1 deletion tests/deploments/test_steps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import sys
from pathlib import Path, PurePath, PurePosixPath
from unittest.mock import patch

import boto3
import pytest
from moto import mock_s3

from prefect_aws.deployments.steps import pull_from_s3, push_to_s3
from prefect_aws import AwsCredentials
from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3


@pytest.fixture
Expand Down Expand Up @@ -173,6 +175,98 @@ def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials):
assert not (tmp_path / "empty2_copy").exists()


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8+")
def test_s3_session_with_params():
    """Check the arguments `get_s3_client` hands to the session and client.

    `boto3.Session` is patched, and `get_s3_client` is called three times:
    with client parameters nested inside the credentials dict, with a
    separate `client_parameters` dict, and with the dict form of an
    `AwsCredentials` block. Each call is expected to record two entries on
    the mock: the session construction followed by the "s3" client creation.
    """
    with patch("boto3.Session") as session_mock:
        # Scenario 1: client parameters nested inside the credentials dict.
        get_s3_client(
            credentials={
                "aws_access_key_id": "THE_KEY",
                "aws_secret_access_key": "SHHH!",
                "profile_name": "foo",
                "region_name": "us-weast-1",
                "aws_client_parameters": {
                    "api_version": "v1",
                    "config": {"connect_timeout": 300},
                },
            }
        )
        # Scenario 2: region and config supplied through client_parameters.
        get_s3_client(
            credentials={
                "aws_access_key_id": "THE_KEY",
                "aws_secret_access_key": "SHHH!",
            },
            client_parameters={
                "region_name": "us-west-1",
                "config": {"signature_version": "s3v4"},
            },
        )
        # Scenario 3: everything comes from an AwsCredentials block.
        creds_block = AwsCredentials(
            aws_access_key_id="BlockKey",
            aws_secret_access_key="BlockSecret",
            aws_session_token="BlockToken",
            profile_name="BlockProfile",
            region_name="BlockRegion",
            aws_client_parameters={
                "api_version": "v1",
                "use_ssl": True,
                "verify": True,
                "endpoint_url": "BlockEndpoint",
                "config": {"connect_timeout": 123},
            },
        )
        get_s3_client(credentials=creds_block.dict())

        recorded = session_mock.mock_calls
        assert len(recorded) == 6
        (
            nested_session,
            nested_client,
            split_session,
            split_client,
            block_session,
            block_client,
        ) = recorded

        # --- Scenario 1 ---
        expected_nested_session = {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": "foo",
            "region_name": "us-weast-1",
        }
        assert nested_session.kwargs == expected_nested_session
        assert nested_client.args[0] == "s3"
        expected_nested_client = {
            "api_version": "v1",
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_nested_client.items() <= nested_client.kwargs.items()
        nested_config = nested_client.kwargs.get("config")
        assert nested_config.connect_timeout == 300
        assert nested_config.signature_version is None

        # --- Scenario 2 ---
        expected_split_session = {
            "aws_access_key_id": "THE_KEY",
            "aws_secret_access_key": "SHHH!",
            "aws_session_token": None,
            "profile_name": None,
            "region_name": "us-west-1",
        }
        assert split_session.kwargs == expected_split_session
        assert split_client.args[0] == "s3"
        expected_split_client = {
            "api_version": None,
            "endpoint_url": None,
            "use_ssl": None,
            "verify": None,
        }
        assert expected_split_client.items() <= split_client.kwargs.items()
        split_config = split_client.kwargs.get("config")
        assert split_config.connect_timeout == 60
        assert split_config.signature_version == "s3v4"

        # --- Scenario 3 ---
        expected_block_session = {
            "aws_access_key_id": "BlockKey",
            "aws_secret_access_key": creds_block.aws_secret_access_key,
            "aws_session_token": "BlockToken",
            "profile_name": "BlockProfile",
            "region_name": "BlockRegion",
        }
        assert block_session.kwargs == expected_block_session
        assert block_client.args[0] == "s3"
        expected_block_client = {
            "api_version": "v1",
            "use_ssl": True,
            "verify": True,
            "endpoint_url": "BlockEndpoint",
        }
        assert expected_block_client.items() <= block_client.kwargs.items()
        block_config = block_client.kwargs.get("config")
        assert block_config.connect_timeout == 123
        assert block_config.signature_version is None


def test_custom_credentials_and_client_parameters(s3_setup, tmp_files):
s3, bucket_name = s3_setup
folder = "my-project"
Expand Down