diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f0fdc3..1445d22c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Use cached boto3 clients in `ECSWorker` - [#375](https://github.com/PrefectHQ/prefect-aws/pull/375) + ### Fixed ### Deprecated @@ -18,15 +20,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed ## 0.4.9 -Released January 24rd, 2024; + +Released January 24rd, 2024. ### Fixed - Hashing of nested objects within `AwsClientParameters` - [#373](https://github.com/PrefectHQ/prefect-aws/pull/373) - ## 0.4.8 -Released January 23rd, 2024; + +Released January 23rd, 2024. ### Added diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 6d3c8ec7..372e6b29 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -51,12 +51,11 @@ import sys import time from copy import deepcopy -from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union from uuid import UUID import anyio import anyio.abc -import boto3 import yaml from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.server.schemas.core import FlowRun @@ -79,7 +78,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed, wait_random from typing_extensions import Literal -from prefect_aws import AwsCredentials +from prefect_aws.credentials import AwsCredentials, ClientType # Internal type alias for ECS clients which are generated dynamically in botocore _ECSClient = Any @@ -584,8 +583,8 @@ async def run( """ Runs a given flow run on the current worker. """ - boto_session, ecs_client = await run_sync_in_worker_thread( - self._get_session_and_client, configuration + ecs_client = await run_sync_in_worker_thread( + self._get_client, configuration, "ecs" ) logger = self.get_flow_run_logger(flow_run) @@ -598,7 +597,6 @@ async def run( ) = await run_sync_in_worker_thread( self._create_task_and_wait_for_start, logger, - boto_session, ecs_client, configuration, flow_run, @@ -625,7 +623,6 @@ async def run( cluster_arn, task_definition, is_new_task_definition and configuration.auto_deregister_task_definition, - boto_session, ecs_client, ) @@ -636,21 +633,17 @@ async def run( status_code=status_code if status_code is not None else -1, ) - def _get_session_and_client( - self, - configuration: ECSJobConfiguration, - ) -> Tuple[boto3.Session, _ECSClient]: + def _get_client( + self, configuration: ECSJobConfiguration, client_type: Union[str, ClientType] + ) -> _ECSClient: """ - Retrieve a boto3 session and ECS client + Get a boto3 client of client_type. Will use a cached client if one exists. """ - boto_session = configuration.aws_credentials.get_boto3_session() - ecs_client = boto_session.client("ecs") - return boto_session, ecs_client + return configuration.aws_credentials.get_client(client_type) def _create_task_and_wait_for_start( self, logger: logging.Logger, - boto_session: boto3.Session, ecs_client: _ECSClient, configuration: ECSJobConfiguration, flow_run: FlowRun, @@ -741,7 +734,6 @@ def _create_task_and_wait_for_start( # Prepare the task run request task_run_request = self._prepare_task_run_request( - boto_session, configuration, task_definition, task_definition_arn, @@ -782,7 +774,6 @@ def _watch_task_and_get_exit_code( cluster_arn: str, task_definition: dict, deregister_task_definition: bool, - boto_session: boto3.Session, ecs_client: _ECSClient, ) -> Optional[int]: """ @@ -798,7 +789,6 @@ def _watch_task_and_get_exit_code( cluster_arn, task_definition, ecs_client, - boto_session, ) if deregister_task_definition: @@ -992,7 +982,6 @@ def _wait_for_task_finish( cluster_arn: str, task_definition: dict, ecs_client: _ECSClient, - boto_session: boto3.Session, ): """ Watch an ECS task until it reaches a STOPPED status. @@ -1031,7 +1020,7 @@ def _wait_for_task_finish( else: # Prepare to stream the output log_config = container_def["logConfiguration"]["options"] - logs_client = boto_session.client("logs") + logs_client = self._get_client(configuration, "logs") can_stream_output = True # Track the last log timestamp to prevent double display last_log_timestamp: Optional[int] = None @@ -1300,13 +1289,13 @@ def _prepare_task_definition( return task_definition def _load_network_configuration( - self, vpc_id: Optional[str], boto_session: boto3.Session + self, vpc_id: Optional[str], configuration: ECSJobConfiguration ) -> dict: """ Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. """ - ec2_client = boto_session.client("ec2") + ec2_client = self._get_client(configuration, "ec2") vpc_message = "the default VPC" if not vpc_id else f"VPC with ID {vpc_id}" if not vpc_id: @@ -1347,13 +1336,16 @@ def _load_network_configuration( } def _custom_network_configuration( - self, vpc_id: str, network_configuration: dict, boto_session: boto3.Session + self, + vpc_id: str, + network_configuration: dict, + configuration: ECSJobConfiguration, ) -> dict: """ Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. """ - ec2_client = boto_session.client("ec2") + ec2_client = self._get_client(configuration, "ec2") vpc_message = f"VPC with ID {vpc_id}" vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs") @@ -1389,7 +1381,6 @@ def _custom_network_configuration( def _prepare_task_run_request( self, - boto_session: boto3.Session, configuration: ECSJobConfiguration, task_definition: dict, task_definition_arn: str, @@ -1422,7 +1413,7 @@ def _prepare_task_run_request( and not configuration.network_configuration ): task_run_request["networkConfiguration"] = self._load_network_configuration( - configuration.vpc_id, boto_session + configuration.vpc_id, configuration ) # Use networkConfiguration if supplied by user @@ -1435,7 +1426,7 @@ def _prepare_task_run_request( self._custom_network_configuration( configuration.vpc_id, configuration.network_configuration, - boto_session, + configuration, ) ) @@ -1628,7 +1619,7 @@ def _stop_task( f"{cluster!r}." ) - _, ecs_client = self._get_session_and_client(configuration) + ecs_client = self._get_client(configuration, "ecs") try: ecs_client.stop_task(cluster=cluster, task=task) except Exception as exc: diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 077a178a..c9fbaabf 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -21,6 +21,7 @@ from tenacity import RetryError +from prefect_aws.credentials import _get_client_cached from prefect_aws.workers.ecs_worker import ( _TASK_DEFINITION_CACHE, ECS_DEFAULT_CONTAINER_NAME, @@ -2183,6 +2184,25 @@ async def test_retry_on_failed_task_start( assert run_task_mock.call_count == 3 +@pytest.mark.usefixtures("ecs_mocks") +async def test_worker_uses_cached_boto3_client(aws_credentials: AwsCredentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + async with ECSWorker(work_pool_name="test") as worker: + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + async def test_mask_sensitive_env_values(): task_run_request = { "overrides": {