diff --git a/packages/aws-cdk-lib/aws-ecs/lib/drain-hook/lambda-source/index.py b/packages/aws-cdk-lib/aws-ecs/lib/drain-hook/lambda-source/index.py index 73502aca37ac2..dff24459d1744 100644 --- a/packages/aws-cdk-lib/aws-ecs/lib/drain-hook/lambda-source/index.py +++ b/packages/aws-cdk-lib/aws-ecs/lib/drain-hook/lambda-source/index.py @@ -6,16 +6,17 @@ def lambda_handler(event, context): print(json.dumps(dict(event, ResponseURL='...'))) + cluster = os.environ['CLUSTER'] - snsTopicArn = event['Records'][0]['Sns']['TopicArn'] lifecycle_event = json.loads(event['Records'][0]['Sns']['Message']) - instance_id = lifecycle_event.get('EC2InstanceId') + instance_id = lifecycle_event.get('EC2InstanceId', None) + if not instance_id: - print('Got event without EC2InstanceId: %s', json.dumps(dict(event, ResponseURL='...'))) + print(f'Got event without EC2InstanceId: {json.dumps(dict(event, ResponseURL="..."))}') return instance_arn = container_instance_arn(cluster, instance_id) - print('Instance %s has container instance ARN %s' % (lifecycle_event['EC2InstanceId'], instance_arn)) + print(f'Instance {lifecycle_event["EC2InstanceId"]} has container instance ARN {instance_arn}') if not instance_arn: return @@ -23,11 +24,15 @@ def lambda_handler(event, context): task_arns = container_instance_task_arns(cluster, instance_arn) if task_arns: - print('Instance ARN %s has task ARNs %s' % (instance_arn, ', '.join(task_arns))) + print(f'Instance ARN {instance_arn} has task ARNs {", ".join(task_arns)}') while has_tasks(cluster, instance_arn, task_arns): time.sleep(10) + complete_lifecycle_action(instance_id, lifecycle_event) + + +def complete_lifecycle_action(instance_id, lifecycle_event): try: print('Terminating instance %s' % instance_id) autoscaling.complete_lifecycle_action( @@ -40,19 +45,29 @@ def lambda_handler(event, context): def container_instance_arn(cluster, instance_id): """Turn an instance ID into a container instance ARN.""" - arns = ecs.list_container_instances(cluster=cluster, filter='ec2InstanceId==' + instance_id)['containerInstanceArns'] + arns = list_container_instances(cluster, instance_id) if not arns: return None return arns[0] + +def list_container_instances(cluster, instance_id): + return ecs.list_container_instances(cluster=cluster, filter='ec2InstanceId==' + instance_id)['containerInstanceArns'] + + def container_instance_task_arns(cluster, instance_arn): """Fetch tasks for a container instance ARN.""" - arns = ecs.list_tasks(cluster=cluster, containerInstance=instance_arn)['taskArns'] + arns = list_tasks(cluster, instance_arn) return arns + +def list_tasks(cluster, instance_arn): + return ecs.list_tasks(cluster=cluster, containerInstance=instance_arn)['taskArns'] + + def has_tasks(cluster, instance_arn, task_arns): """Return True if the instance is running tasks for the given cluster.""" - instances = ecs.describe_container_instances(cluster=cluster, containerInstances=[instance_arn])['containerInstances'] + instances = describe_container_instances(cluster, instance_arn) if not instances: return False instance = instances[0] @@ -66,7 +81,7 @@ def has_tasks(cluster, instance_arn, task_arns): if task_arns: # Fetch details for tasks running on the container instance - tasks = ecs.describe_tasks(cluster=cluster, tasks=task_arns)['tasks'] + tasks = describe_tasks(cluster, task_arns) if tasks: # Consider any non-stopped tasks as running task_count = sum(task['lastStatus'] != 'STOPPED' for task in tasks) + instance['pendingTasksCount'] @@ -79,6 +94,15 @@ def has_tasks(cluster, instance_arn, task_arns): return task_count > 0 + +def describe_container_instances(cluster, instance_arn): + return ecs.describe_container_instances(cluster=cluster, containerInstances=[instance_arn])['containerInstances'] + + +def describe_tasks(cluster, task_arns): + return ecs.describe_tasks(cluster=cluster, tasks=task_arns)['tasks'] + + def set_container_instance_to_draining(cluster, instance_arn): ecs.update_container_instances_state( cluster=cluster, diff --git a/packages/aws-cdk-lib/aws-ecs/test/drain-hook/Dockerfile b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/Dockerfile new file mode 100644 index 0000000000000..79fe95a9d385e --- /dev/null +++ b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/python:3.7 + +ADD . /opt/lambda +WORKDIR /opt/lambda + +RUN pip3 install boto3==1.17.42 +RUN python3 test_index.py + +ENTRYPOINT [ "/bin/bash" ] diff --git a/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test.sh b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test.sh new file mode 100755 index 0000000000000..887efad19bc74 --- /dev/null +++ b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test.sh @@ -0,0 +1,27 @@ +#!/bin/bash +#--------------------------------------------------------------------------------------------------- +# executes unit tests +# +# prepares a staging directory with the requirements +set -e +script_dir=$(cd $(dirname $0) && pwd) + +# prepare staging directory +staging=$(mktemp -d) +mkdir -p ${staging} +cd ${staging} + +# copy src and overlay with test +cp ${script_dir}/../../lib/drain-hook/lambda-source/index.py $PWD +cp ${script_dir}/test_index.py $PWD +cp ${script_dir}/Dockerfile $PWD + +DRAIN_HOOK_TEST_NO_DOCKER=${DRAIN_HOOK_TEST_NO_DOCKER:-""} +DOCKER_CMD=${CDK_DOCKER:-docker} + +if [ -z ${DRAIN_HOOK_TEST_NO_DOCKER} ]; then + # this will run our tests inside the right environment + $DOCKER_CMD build . +else + python3 test_index.py +fi diff --git a/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test_index.py b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test_index.py new file mode 100644 index 0000000000000..acf9c10d92241 --- /dev/null +++ b/packages/aws-cdk-lib/aws-ecs/test/drain-hook/test_index.py @@ -0,0 +1,114 @@ +import unittest +import os +import sys +from unittest.mock import patch + +os.environ["CLUSTER"] = "my-cluster" + +try: + # this is available only if executed with ./test.sh + import index +except ModuleNotFoundError as _: + print( + "Unable to import index. Use ./test.sh to run these tests. " + + 'If you want to avoid running them in docker, run "DRAIN_HOOK_TEST_NO_DOCKER=true ./test.sh"' + ) + sys.exit(1) + + +def make_event(): + records = [] + records.append({'Sns': {'Message': '{"EC2InstanceId": "i-xxxxxx", "LifecycleHookName": "my-hook", "LifecycleActionToken": "my-token", "AutoScalingGroupName": "my-asg"}'}}) + return {'Records': records} + + +def make_event_no_instance_id(): + records = [] + records.append({'Sns': {'Message': '{"food": "bar"}'}}) + return {'Records': records} + + +class DrainHookTest(unittest.TestCase): + @patch("index.list_container_instances") + def test_no_instance_id(self, list): + event = make_event_no_instance_id() + index.lambda_handler(event, {}) + list.assert_not_called() + + @patch("index.complete_lifecycle_action") + @patch("index.list_tasks") + @patch("index.list_container_instances") + def test_no_instance_arn(self, list, tasks, complete): + event = make_event() + + list.return_value = [] + index.lambda_handler(event, {}) + + list.assert_called_once_with( + os.environ["CLUSTER"], + "i-xxxxxx", + ) + tasks.assert_not_called() + complete.assert_not_called() + + @patch("index.complete_lifecycle_action") + @patch("index.describe_container_instances") + @patch("index.list_tasks") + @patch("index.list_container_instances") + def test_no_list_tasks_no_container_instances(self, list, tasks, describe, complete): + event = make_event() + + list.return_value = ['some-container-instance-arn'] + tasks.return_value = [] + describe.return_value = [] + index.lambda_handler(event, {}) + + list.assert_called_once_with( + os.environ["CLUSTER"], + "i-xxxxxx", + ) + tasks.assert_called_once_with( + os.environ["CLUSTER"], + 'some-container-instance-arn', + ) + describe.assert_called_once_with( + os.environ["CLUSTER"], + 'some-container-instance-arn', + ) + complete.assert_called_once() + + @patch("index.complete_lifecycle_action") + @patch("index.describe_tasks") + @patch("index.describe_container_instances") + @patch("index.list_tasks") + @patch("index.list_container_instances") + def test_has_list_tasks_no_describe_tasks(self, list, tasks, describe, describe_tasks, complete): + event = make_event() + + list.return_value = ['some-container-instance-arn'] + tasks.return_value = ['task-arn'] + describe.return_value = [{'id': 'i-xxxx', 'status': 'TERMINATED', 'runningTasksCount': 0, 'pendingTasksCount': 0}] + describe_tasks.return_value = [] + index.lambda_handler(event, {}) + + list.assert_called_once_with( + os.environ["CLUSTER"], + "i-xxxxxx", + ) + tasks.assert_called_once_with( + os.environ["CLUSTER"], + 'some-container-instance-arn', + ) + describe.assert_called_once_with( + os.environ["CLUSTER"], + 'some-container-instance-arn', + ) + describe_tasks.assert_called_once_with( + os.environ["CLUSTER"], + tasks.return_value, + ) + complete.assert_called_once() + + +if __name__ == "__main__": + unittest.main()