@staticmethod
def get_corresponding_worker_type() -> str:
    """Return the corresponding worker type for this infrastructure block."""
    return ECSWorker.type

async def generate_work_pool_base_job_template(self) -> dict:
    """
    Generate a base job template for an ECS work pool with the same
    configuration as this block.

    Starts from a deep copy of `ECSWorker.get_default_base_job_template()`
    and copies every explicitly-set, non-default field of this block into
    the template, then tries to apply `task_customizations` to the
    template's `task_run_request`.

    Returns:
        - dict: a base job template for an ECS work pool

    Raises:
        - BlockNotSavedError: if `aws_credentials` is set on this block but
            has no block document id (i.e. it has not been saved).
    """
    # Block bookkeeping fields and fields that are handled separately below;
    # none of these map onto a work pool variable.
    skipped_keys = (
        "type",
        "block_type_slug",
        "_block_document_id",
        "_block_document_name",
        "_is_anonymous",
        "task_customizations",
    )

    base_job_template = copy.deepcopy(ECSWorker.get_default_base_job_template())
    variables = base_job_template["variables"]["properties"]
    job_configuration = base_job_template["job_configuration"]

    for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items():
        if key == "command":
            # The block stores the command as a list of arguments while the
            # template variable takes a single shell-style string.
            variables["command"]["default"] = shlex.join(value)
        elif key in skipped_keys:
            continue
        elif key == "aws_credentials":
            if not self.aws_credentials._block_document_id:
                raise BlockNotSavedError(
                    "It looks like you are trying to use a block that"
                    " has not been saved. Please call `.save` on your block"
                    " before publishing it as a work pool."
                )
            # Reference the credentials block by document id rather than
            # embedding its values in the template.
            variables["aws_credentials"]["default"] = {
                "$ref": {
                    "block_document_id": str(
                        self.aws_credentials._block_document_id
                    )
                }
            }
        elif key == "task_definition":
            # A full task definition is not a template variable; it is
            # written into the job configuration directly.
            job_configuration["task_definition"] = value
        elif key in variables:
            variables[key]["default"] = value
        else:
            self.logger.warning(
                f"Variable {key!r} is not supported by ECS work pools."
                " Skipping."
            )

    if self.task_customizations:
        try:
            job_configuration["task_run_request"] = self.task_customizations.apply(
                job_configuration["task_run_request"]
            )
        except JsonPointerException:
            # NOTE(review): the two sentences below are concatenated without
            # a separating space. The text is kept verbatim because the test
            # added alongside this method matches the message exactly.
            self.logger.warning(
                "Unable to apply task customizations to the base job template."
                "You may need to update the template manually."
            )

    return base_job_template
# Fixture from tests/conftest.py.
@pytest.fixture
def aws_credentials():
    """Build an `AwsCredentials` block and save it as "test-creds-block"."""
    credentials_block = AwsCredentials(
        aws_access_key_id="access_key_id",
        aws_secret_access_key="secret_access_key",
        region_name="us-east-1",
    )
    # The expected templates below read `_block_document_id` from this block;
    # presumably it is only populated once the block has been saved.
    credentials_block.save("test-creds-block", overwrite=True)
    return credentials_block


# Fixtures and test from tests/test_ecs.py.
@pytest.fixture
def default_base_job_template():
    """Return a private copy of the ECS worker's default base job template."""
    pristine = ECSWorker.get_default_base_job_template()
    return deepcopy(pristine)


@pytest.fixture
def base_job_template_with_defaults(default_base_job_template, aws_credentials):
    """Return the default template with a `default` set on each listed variable."""
    role_arn = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
    variable_defaults = {
        "command": "python my_script.py",
        "env": {
            "VAR1": "value1",
            "VAR2": "value2",
        },
        "labels": {
            "label1": "value1",
            "label2": "value2",
        },
        "name": "prefect-job",
        "image": "docker.io/my_image:latest",
        "aws_credentials": {
            "$ref": {"block_document_id": str(aws_credentials._block_document_id)}
        },
        "launch_type": "FARGATE_SPOT",
        "vpc_id": "vpc-123456",
        "task_role_arn": role_arn,
        "execution_role_arn": role_arn,
        "cluster": "test-cluster",
        "cpu": 2048,
        "memory": 4096,
        "family": "test-family",
        "task_definition_arn": (
            "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
        ),
        "cloudwatch_logs_options": {
            "awslogs-group": "prefect",
            "awslogs-region": "us-east-1",
            "awslogs-stream-prefix": "prefect",
        },
        "configure_cloudwatch_logs": True,
        "stream_output": True,
        "task_watch_poll_interval": 5.1,
        "task_start_timeout_seconds": 60,
        "auto_deregister_task_definition": False,
    }

    template = deepcopy(default_base_job_template)
    properties = template["variables"]["properties"]
    for variable_name, default_value in variable_defaults.items():
        # Index (rather than setdefault) so an unknown variable name fails loudly.
        properties[variable_name]["default"] = default_value
    return template


@pytest.fixture
def base_job_template_with_task_arn(default_base_job_template, aws_credentials):
    """Return the default template with an image default and a task definition."""
    template = deepcopy(default_base_job_template)

    image_variable = template["variables"]["properties"]["image"]
    image_variable["default"] = "docker.io/my_image:latest"

    task_definition = {
        "containerDefinitions": [
            {"image": "docker.io/my_image:latest", "name": "prefect-job"}
        ],
        "cpu": "2048",
        "family": "test-family",
        "memory": "2024",
        "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
    }
    template["job_configuration"]["task_definition"] = task_definition
    return template


@pytest.mark.parametrize(
    "job_config",
    [
        "default",
        "custom",
        "task_definition_arn",
    ],
)
async def test_generate_work_pool_base_job_template(
    job_config,
    base_job_template_with_defaults,
    aws_credentials,
    default_base_job_template,
    base_job_template_with_task_arn,
    caplog,
):
    """Check the generated base job template for each `ECSTask` configuration."""
    # Baseline case: a block with no arguments, whose expected template is
    # the default template with the image default filled in.
    job = ECSTask()
    expected_template = default_base_job_template
    image_variable = expected_template["variables"]["properties"]["image"]
    image_variable["default"] = get_prefect_image_name()

    if job_config == "custom":
        expected_template = base_job_template_with_defaults
        role_arn = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
        security_group_patch = {
            "op": "add",
            "path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
            "value": ["sg-d72e9599956a084f5"],
        }
        block_kwargs = dict(
            command=["python", "my_script.py"],
            env={"VAR1": "value1", "VAR2": "value2"},
            labels={"label1": "value1", "label2": "value2"},
            name="prefect-job",
            image="docker.io/my_image:latest",
            aws_credentials=aws_credentials,
            launch_type="FARGATE_SPOT",
            vpc_id="vpc-123456",
            task_role_arn=role_arn,
            execution_role_arn=role_arn,
            cluster="test-cluster",
            cpu=2048,
            memory=4096,
            task_customizations=[security_group_patch],
            family="test-family",
            task_definition_arn=(
                "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
            ),
            cloudwatch_logs_options={
                "awslogs-group": "prefect",
                "awslogs-region": "us-east-1",
                "awslogs-stream-prefix": "prefect",
            },
            configure_cloudwatch_logs=True,
            stream_output=True,
            task_watch_poll_interval=5.1,
            task_start_timeout_seconds=60,
            auto_deregister_task_definition=False,
        )
        job = ECSTask(**block_kwargs)
    elif job_config == "task_definition_arn":
        expected_template = base_job_template_with_task_arn
        task_definition = {
            "containerDefinitions": [
                {"image": "docker.io/my_image:latest", "name": "prefect-job"}
            ],
            "cpu": "2048",
            "family": "test-family",
            "memory": "2024",
            "executionRoleArn": (
                "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
            ),
        }
        job = ECSTask(
            image="docker.io/my_image:latest",
            task_definition=task_definition,
        )

    generated = await job.generate_work_pool_base_job_template()

    assert generated == expected_template

    if job_config == "custom":
        # The custom block carries a task customization, and this case
        # expects the "unable to apply" warning to have been logged.
        expected_warning = (
            "Unable to apply task customizations to the base job template."
            "You may need to update the template manually."
        )
        assert expected_warning in caplog.text