diff --git a/CHANGELOG.md b/CHANGELOG.md index c73c15ff..a310bab3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314) +- Changed `push_to_s3` and `pull_from_s3` deployment steps to properly create a boto3 session client if the passed credentials are a referenced AwsCredentials block [#322](https://github.com/PrefectHQ/prefect-aws/pull/322) ### Fixed diff --git a/prefect_aws/deployments/steps.py b/prefect_aws/deployments/steps.py index 5609a044..77930465 100644 --- a/prefect_aws/deployments/steps.py +++ b/prefect_aws/deployments/steps.py @@ -91,14 +91,7 @@ def push_to_s3( ``` """ - if credentials is None: - credentials = {} - if client_parameters is None: - client_parameters = {} - advanced_config = client_parameters.pop("config", {}) - client = boto3.client( - "s3", **credentials, **client_parameters, config=Config(**advanced_config) - ) + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) local_path = Path.cwd() @@ -117,7 +110,7 @@ def push_to_s3( continue elif not local_file_path.is_dir(): remote_file_path = Path(folder) / local_file_path.relative_to(local_path) - client.upload_file( + s3.upload_file( str(local_file_path), bucket, str(remote_file_path.as_posix()) ) @@ -174,14 +167,7 @@ def pull_from_s3( credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}" ``` """ - if credentials is None: - credentials = {} - if client_parameters is None: - client_parameters = {} - advanced_config = client_parameters.pop("config", {}) - - session = boto3.Session(**credentials) - s3 = session.client("s3", **client_parameters, config=Config(**advanced_config)) + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) local_path = Path.cwd() @@ -206,3 +192,51 @@ def pull_from_s3( "folder": folder, "directory": str(local_path), } + + +def get_s3_client( + credentials: Optional[Dict] = None, + client_parameters: Optional[Dict] = None, +) -> dict: + if credentials is None: + credentials = {} + if client_parameters is None: + client_parameters = {} + + # Get credentials from credentials (regardless if block or not) + aws_access_key_id = credentials.get("aws_access_key_id", None) + aws_secret_access_key = credentials.get("aws_secret_access_key", None) + aws_session_token = credentials.get("aws_session_token", None) + + # Get remaining session info from credentials, or client_parameters + profile_name = credentials.get( + "profile_name", client_parameters.get("profile_name", None) + ) + region_name = credentials.get( + "region_name", client_parameters.get("region_name", None) + ) + + # Get additional info from client_parameters, otherwise credentials input (if block) + aws_client_parameters = credentials.get("aws_client_parameters", client_parameters) + api_version = aws_client_parameters.get("api_version", None) + endpoint_url = aws_client_parameters.get("endpoint_url", None) + use_ssl = aws_client_parameters.get("use_ssl", None) + verify = aws_client_parameters.get("verify", None) + config_params = aws_client_parameters.get("config", {}) + config = Config(**config_params) + + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + profile_name=profile_name, + region_name=region_name, + ) + return session.client( + "s3", + api_version=api_version, + endpoint_url=endpoint_url, + use_ssl=use_ssl, + verify=verify, + config=config, + ) diff --git a/tests/deploments/test_steps.py b/tests/deploments/test_steps.py index a2312d18..c78c3578 100644 --- a/tests/deploments/test_steps.py +++ b/tests/deploments/test_steps.py @@ -1,12 +1,14 @@ import os import sys from pathlib import Path, PurePath, PurePosixPath +from unittest.mock import patch import boto3 import pytest from moto import mock_s3 -from prefect_aws.deployments.steps import pull_from_s3, push_to_s3 +from prefect_aws import AwsCredentials +from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3 @pytest.fixture @@ -173,6 +175,98 @@ def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials): assert not (tmp_path / "empty2_copy").exists() +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8+") +def test_s3_session_with_params(): + with patch("boto3.Session") as mock_session: + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "profile_name": "foo", + "region_name": "us-weast-1", + "aws_client_parameters": { + "api_version": "v1", + "config": {"connect_timeout": 300}, + }, + } + ) + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + }, + client_parameters={ + "region_name": "us-west-1", + "config": {"signature_version": "s3v4"}, + }, + ) + creds_block = AwsCredentials( + aws_access_key_id="BlockKey", + aws_secret_access_key="BlockSecret", + aws_session_token="BlockToken", + profile_name="BlockProfile", + region_name="BlockRegion", + aws_client_parameters={ + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + "config": {"connect_timeout": 123}, + }, + ) + get_s3_client(credentials=creds_block.dict()) + all_calls = mock_session.mock_calls + assert len(all_calls) == 6 + assert all_calls[0].kwargs == { + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": "foo", + "region_name": "us-weast-1", + } + assert all_calls[1].args[0] == "s3" + assert { + "api_version": "v1", + "endpoint_url": None, + "use_ssl": None, + "verify": None, + }.items() <= all_calls[1].kwargs.items() + assert all_calls[1].kwargs.get("config").connect_timeout == 300 + assert all_calls[1].kwargs.get("config").signature_version is None + assert all_calls[2].kwargs == { + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": None, + "region_name": "us-west-1", + } + assert all_calls[3].args[0] == "s3" + assert { + "api_version": None, + "endpoint_url": None, + "use_ssl": None, + "verify": None, + }.items() <= all_calls[3].kwargs.items() + assert all_calls[3].kwargs.get("config").connect_timeout == 60 + assert all_calls[3].kwargs.get("config").signature_version == "s3v4" + assert all_calls[4].kwargs == { + "aws_access_key_id": "BlockKey", + "aws_secret_access_key": creds_block.aws_secret_access_key, + "aws_session_token": "BlockToken", + "profile_name": "BlockProfile", + "region_name": "BlockRegion", + } + assert all_calls[5].args[0] == "s3" + assert { + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + }.items() <= all_calls[5].kwargs.items() + assert all_calls[5].kwargs.get("config").connect_timeout == 123 + assert all_calls[5].kwargs.get("config").signature_version is None + + def test_custom_credentials_and_client_parameters(s3_setup, tmp_files): s3, bucket_name = s3_setup folder = "my-project"