diff --git a/pyproject.toml b/pyproject.toml index d300691..2265053 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] requires-python = ">=3.7" dependencies = [ - "boto3 ~= 1.26", + "boto3 ~= 1.26", ] [project.entry-points.pytest11] @@ -58,11 +58,15 @@ files = [ "src/**/*.py" ] [[tool.mypy.overrides]] module = [ - "boto3", - "botocore.*" + "boto3", + "botocore.*" ] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "moto" +ignore_missing_imports = true + [tool.ruff] ignore = [ "E501", @@ -80,26 +84,31 @@ line-length = 100 [tool.pytest.ini_options] xfail_strict = true addopts = [ - "--durations=5", - "--cov=src/deadline_test_scaffolding", - "--color=yes", - "--cov-report=html:build/coverage", - "--cov-report=xml:build/coverage/coverage.xml", - "--cov-report=term-missing", - "--numprocesses=auto", + "--durations=5", + "--cov=src/deadline_test_scaffolding", + "--color=yes", + "--cov-report=html:build/coverage", + "--cov-report=xml:build/coverage/coverage.xml", + "--cov-report=term-missing", + "--numprocesses=auto", ] testpaths = [ "test" ] looponfailroots = [ - "src", - "test", + "src", + "test", ] # looponfailroots is deprecated, this removes the deprecation from the test output filterwarnings = [ - "ignore::DeprecationWarning" + "ignore::DeprecationWarning" ] [tool.coverage.run] source_pkgs = [ "deadline_test_scaffolding" ] +omit = [ + "models.py", + "fixtures.py", + "deadline/stubs.py", +] [tool.coverage.paths] @@ -108,4 +117,4 @@ source = [ ] [tool.coverage.report] -show_missing = true \ No newline at end of file +show_missing = true diff --git a/requirements-testing.txt b/requirements-testing.txt index 6c34d78..5178d15 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -4,6 +4,7 @@ pytest-cov ~= 4.1 pytest-timeout ~= 2.1 pytest-xdist ~= 3.3 black ~= 23.7 +moto[all] ~= 4.2 mypy == 1.5.0 ruff ~= 0.0.284 twine ~= 4.0 \ No newline at end of file diff --git a/src/deadline_test_scaffolding/__init__.py b/src/deadline_test_scaffolding/__init__.py index 9a05c03..83ecc60 100644 --- a/src/deadline_test_scaffolding/__init__.py +++ b/src/deadline_test_scaffolding/__init__.py @@ -1,17 +1,66 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
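
Aside: requirements-testing.txt now pulls in moto[all] and pyproject.toml gains a matching mypy override, so unit tests can exercise AWS-facing code against in-process fakes instead of real endpoints. A minimal sketch of the kind of test this enables (the test and bucket names are invented; moto 4.x exposes per-service decorators such as mock_s3):

    import boto3
    from moto import mock_s3

    @mock_s3
    def test_put_and_get_object():
        # All boto3 calls inside the decorated test hit moto's in-memory S3.
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="scaffolding-demo-bucket")
        s3.put_object(Bucket="scaffolding-demo-bucket", Key="k", Body=b"v")
        assert s3.get_object(Bucket="scaffolding-demo-bucket", Key="k")["Body"].read() == b"v"
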
-from .deadline_manager import DeadlineManager, DeadlineClient
-from .deadline_stub import StubDeadlineClient
-from .fixtures import deadline_manager_fixture, deadline_scaffolding, create_worker_agent
+from .deadline import (
+    CommandResult,
+    DeadlineClient,
+    DeadlineWorker,
+    DeadlineWorkerConfiguration,
+    DockerContainerWorker,
+    EC2InstanceWorker,
+    Job,
+    Farm,
+    Fleet,
+    PipInstall,
+    Queue,
+    QueueFleetAssociation,
+    TaskStatus,
+)
+from .fixtures import (
+    BootstrapResources,
+    DeadlineResources,
+    bootstrap_resources,
+    deadline_client,
+    deadline_resources,
+    deploy_job_attachment_resources,
+    worker,
+)
 from .job_attachment_manager import JobAttachmentManager
+from .models import (
+    CodeArtifactRepositoryInfo,
+    JobAttachmentSettings,
+    S3Object,
+    ServiceModel,
+)
 from ._version import __version__ as version  # noqa
 
 __all__ = [
-    "DeadlineManager",
+    "BootstrapResources",
+    "CodeArtifactRepositoryInfo",
+    "CommandResult",
+    "DeadlineResources",
     "DeadlineClient",
+    "DeadlineWorker",
+    "DeadlineWorkerConfiguration",
+    "DockerContainerWorker",
+    "EC2InstanceWorker",
+    "Farm",
+    "Fleet",
+    "Job",
+    "JobAttachmentSettings",
     "JobAttachmentManager",
-    "deadline_manager_fixture",
-    "deadline_scaffolding",
+    "PipInstall",
+    "S3Object",
+    "ServiceModel",
-    "StubDeadlineClient",
+    "Queue",
+    "QueueFleetAssociation",
+    "TaskStatus",
+    "bootstrap_resources",
+    "deadline_client",
+    "deadline_resources",
+    "deploy_job_attachment_resources",
     "version",
-    "create_worker_agent",
+    "worker",
 ]
diff --git a/src/deadline_test_scaffolding/cf_templates/job_attachments.yaml b/src/deadline_test_scaffolding/cf_templates/job_attachments.yaml
deleted file mode 100644
index a8faba2..0000000
--- a/src/deadline_test_scaffolding/cf_templates/job_attachments.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-AWSTemplateFormatVersion: 2010-09-09
-Parameters:
-  BucketName:
-    Type: String
-Resources:
-  JobAttachmentBucket:
-    Type: AWS::S3::Bucket
-    Properties:
-      BucketName: !Ref BucketName
-      BucketEncryption:
-        ServerSideEncryptionConfiguration:
-          - ServerSideEncryptionByDefault:
-              SSEAlgorithm: AES256
-      PublicAccessBlockConfiguration:
-        BlockPublicAcls: true
-        BlockPublicPolicy: true
-        IgnorePublicAcls: true
-        RestrictPublicBuckets: true
-    UpdateReplacePolicy: Delete
-    DeletionPolicy: Delete
-  # Deny all non-https traffic
-  JobAttachmentBucketPolicy:
-    Type: AWS::S3::BucketPolicy
-    Properties:
-      Bucket:
-        Ref: JobAttachmentBucket
-      PolicyDocument:
-        Statement:
-          - Action: s3:*
-            Condition:
-              Bool:
-                aws:SecureTransport: "false"
-            Effect: Deny
-            Principal:
-              AWS: "*"
-            Resource:
-              - !GetAtt JobAttachmentBucket.Arn
-              - !Join
-                - ""
-                - - !GetAtt JobAttachmentBucket.Arn
-                  - /*
-        Version: "2012-10-17"
diff --git a/src/deadline_test_scaffolding/cloudformation/__init__.py b/src/deadline_test_scaffolding/cloudformation/__init__.py
new file mode 100644
index 0000000..4e600b8
--- /dev/null
+++ b/src/deadline_test_scaffolding/cloudformation/__init__.py
@@ -0,0 +1,7 @@
+from .job_attachments_bootstrap_stack import JobAttachmentsBootstrapStack
+from .worker_bootstrap_stack import WorkerBootstrapStack
+
+__all__ = [
+    "JobAttachmentsBootstrapStack",
+    "WorkerBootstrapStack",
+]
diff --git a/src/deadline_test_scaffolding/cloudformation/cfn.py b/src/deadline_test_scaffolding/cloudformation/cfn.py
new file mode 100644
index 0000000..96b2610
--- /dev/null
+++ b/src/deadline_test_scaffolding/cloudformation/cfn.py
@@ -0,0 +1,242 @@
+# Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved. +from __future__ import annotations + +import botocore.client +import botocore.exceptions +import json +import logging +import re + +from ..util import clean_kwargs + +LOG = logging.getLogger(__name__) + + +class CfnStack: + name: str + description: str + _resources: list[CfnResource] + _capabilities: list[str] | None + + def __init__( + self, *, name: str, description: str | None = None, capabilities: list[str] | None = None + ) -> None: + self.name = name + self.description = description or "Stack created by deadline-cloud-test-fixtures" + self._resources = [] + self._capabilities = capabilities + + def deploy( + self, + *, + cfn_client: botocore.client.BaseClient, + ) -> None: + LOG.info(f"Bootstrapping test resources by deploying CloudFormation stack {self.name}") + LOG.info(f"Attempting to update stack {self.name}") + kwargs = clean_kwargs( + { + "StackName": self.name, + "TemplateBody": json.dumps(self.template), + "Capabilities": self._capabilities, + } + ) + try: + cfn_client.update_stack(**kwargs) + waiter = cfn_client.get_waiter("stack_update_complete") + waiter.wait(StackName=self.name) + LOG.info("Stack update complete") + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Message"] == "No updates are to be performed.": + LOG.info("Stack is already up to date") + elif re.match(r"Stack.+does not exist", e.response["Error"]["Message"]): + LOG.info(f"Stack {self.name} does not exist yet. Creating new stack.") + cfn_client.create_stack( + **kwargs, + OnFailure="DELETE", + EnableTerminationProtection=False, + ) + waiter = cfn_client.get_waiter("stack_create_complete") + waiter.wait(StackName=self.name) + LOG.info("Stack create complete") + else: + LOG.exception(f"Unexpected error when attempting to update stack {self.name}: {e}") + raise + + def destroy(self, *, cfn_client: botocore.client.BaseClient) -> None: + cfn_client.delete_stack(StackName=self.name) + + def _add_resource(self, resource: CfnResource) -> None: # pragma: no cover + self._resources.append(resource) + + @property + def template(self) -> dict: + return { + "AWSTemplateFormatVersion": "2010-09-09", + "Description": self.description, + "Resources": {r.logical_name: r.template for r in self._resources}, + } + + +class CfnResource: + logical_name: str + type: str + properties: dict + update_replace_policy: str | None + deletion_policy: str | None + + def __init__( + self, + stack: CfnStack, + type: str, + logical_name: str, + properties: dict, + *, + update_replace_policy: str | None = None, + deletion_policy: str | None = None, + ) -> None: + self.logical_name = logical_name + self.type = type + self.properties = properties + self.update_replace_policy = update_replace_policy + self.deletion_policy = deletion_policy + + stack._add_resource(self) + + @property + def template(self) -> dict: + template = { + "Type": self.type, + "Properties": self.properties, + } + if self.update_replace_policy: + template["UpdateReplacePolicy"] = self.update_replace_policy + if self.deletion_policy: + template["DeletionPolicy"] = self.deletion_policy + return template + + @property + def _physical_name_prop(self) -> str | None: + return None + + @property + def physical_name(self) -> str: + if self._physical_name_prop is None: + raise ValueError(f"Resource type {self.type} does not have a physical name") + if self._physical_name_prop not in self.properties: + raise ValueError( + f"Physical name was not specified for this resource ({self.logical_name})" + ) + return 
self.properties[self._physical_name_prop] + + @property + def ref(self) -> dict: # pragma: no cover + return {"Ref": self.logical_name} + + def get_att(self, name: str) -> dict: # pragma: no cover + return {"Fn::GetAtt": [self.logical_name, name]} + + +class Bucket(CfnResource): # pragma: no cover + _physical_name_prop = "BucketName" + + def __init__( + self, + stack: CfnStack, + logical_name: str, + *, + bucket_name: str | None = None, + **kwargs, + ) -> None: + props = clean_kwargs( + { + "BucketName": bucket_name, + # Always apply secure bucket settings + "BucketEncryption": { + "ServerSideEncryptionConfiguration": [ + {"ServerSideEncryptionByDefault": {"SSEAlgorithm": "AES256"}}, + ], + }, + "PublicAccessBlockConfiguration": { + "BlockPublicAcls": True, + "BlockPublicPolicy": True, + "IgnorePublicAcls": True, + "RestrictPublicBuckets": True, + }, + } + ) + super().__init__(stack, "AWS::S3::Bucket", logical_name, props, **kwargs) + + def arn_for_objects(self, *, pattern: str = "*") -> str: + return f"{self.arn}/{pattern}" + + @property + def arn(self) -> str: + return f"arn:aws:s3:::{self.physical_name}" + + +class BucketPolicy(CfnResource): # pragma: no cover + def __init__( + self, + stack: CfnStack, + logical_name: str, + *, + bucket: Bucket, + policy_document: dict, + **kwargs, + ) -> None: + props = clean_kwargs( + { + "Bucket": bucket.ref, + "PolicyDocument": policy_document, + } + ) + super().__init__(stack, "AWS::S3::BucketPolicy", logical_name, props, **kwargs) + + +class Role(CfnResource): # pragma: no cover + _physical_name_prop = "RoleName" + + def __init__( + self, + stack: CfnStack, + logical_name: str, + *, + assume_role_policy_document: dict, + role_name: str | None = None, + policies: list[dict] | None = None, + managed_policy_arns: list[str] | None = None, + **kwargs, + ) -> None: + props = clean_kwargs( + { + "AssumeRolePolicyDocument": assume_role_policy_document, + "RoleName": role_name, + "Policies": policies, + "ManagedPolicyArns": managed_policy_arns, + } + ) + super().__init__(stack, "AWS::IAM::Role", logical_name, props, **kwargs) + + def format_arn(self, *, account: str) -> str: + return f"arn:aws:iam::{account}:role/{self.physical_name}" + + +class InstanceProfile(CfnResource): # pragma: no cover + _physical_name_prop = "InstanceProfileName" + + def __init__( + self, + stack: CfnStack, + logical_name: str, + *, + roles: list[Role], + instance_profile_name: str | None = None, + **kwargs, + ) -> None: + props = clean_kwargs( + { + "Roles": [role.ref for role in roles], + "InstanceProfileName": instance_profile_name, + } + ) + super().__init__(stack, "AWS::IAM::InstanceProfile", logical_name, props, **kwargs) diff --git a/src/deadline_test_scaffolding/cloudformation/job_attachments_bootstrap_stack.py b/src/deadline_test_scaffolding/cloudformation/job_attachments_bootstrap_stack.py new file mode 100644 index 0000000..9b05a3b --- /dev/null +++ b/src/deadline_test_scaffolding/cloudformation/job_attachments_bootstrap_stack.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
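
For orientation: cfn.py above is a small CloudFormation-authoring layer, and the stack classes that follow are its in-repo consumers. Here is a minimal, hypothetical sketch of the same pattern (stack and bucket names are invented; it assumes AWS credentials with CloudFormation and S3 permissions):

    import boto3

    from deadline_test_scaffolding.cloudformation.cfn import Bucket, CfnStack

    # Resources register themselves on the stack via CfnResource.__init__,
    # so constructing the Bucket is enough to add it to the template.
    stack = CfnStack(name="ScaffoldingDemoStack")
    Bucket(stack, "DemoBucket", bucket_name="scaffolding-demo-bucket-123")

    # deploy() attempts update_stack first and falls back to create_stack
    # when the stack does not exist yet.
    stack.deploy(cfn_client=boto3.client("cloudformation"))
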
+from __future__ import annotations + +from .cfn import ( + Bucket, + BucketPolicy, + CfnStack, +) + + +class JobAttachmentsBootstrapStack(CfnStack): # pragma: no cover + bucket: Bucket + bucket_policy: BucketPolicy + + def __init__( + self, + *, + name: str, + bucket_name: str, + description: str | None = None, + ) -> None: + super().__init__(name=name, description=description) + + self.bucket = Bucket( + self, + "JobAttachmentBucket", + bucket_name=bucket_name, + update_replace_policy="Delete", + deletion_policy="Delete", + ) + self.bucket_policy = BucketPolicy( + self, + "JobAttachmentBucketPolicy", + bucket=self.bucket, + policy_document={ + "Version": "2012-10-17", + "Statement": [ + { + "Action": "s3:*", + "Effect": "Deny", + "Principal": "*", + "Resource": [ + self.bucket.arn, + self.bucket.arn_for_objects(), + ], + "Condition": {"Bool": {"aws:SecureTransport": "false"}}, + }, + ], + }, + ) diff --git a/src/deadline_test_scaffolding/utils.py b/src/deadline_test_scaffolding/cloudformation/worker_bootstrap_stack.py similarity index 53% rename from src/deadline_test_scaffolding/utils.py rename to src/deadline_test_scaffolding/cloudformation/worker_bootstrap_stack.py index b903b63..2bdcb5c 100644 --- a/src/deadline_test_scaffolding/utils.py +++ b/src/deadline_test_scaffolding/cloudformation/worker_bootstrap_stack.py @@ -1,129 +1,89 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from __future__ import annotations -from .constants import ( - DEADLINE_WORKER_ROLE, - DEADLINE_WORKER_BOOTSTRAP_ROLE, - DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME, - JOB_ATTACHMENTS_BUCKET_NAME, - JOB_ATTACHMENTS_BUCKET_RESOURCE, - DEADLINE_SERVICE_MODEL_BUCKET, - CODEARTIFACT_DOMAIN, - CODEARTIFACT_ACCOUNT_ID, - CODEARTIFACT_REPOSITORY, - DEADLINE_QUEUE_SESSION_ROLE, - CREDENTIAL_VENDING_PRINCIPAL, +from .cfn import ( + Bucket, + BucketPolicy, + CfnResource, + CfnStack, + InstanceProfile, + Role, ) +from ..models import CodeArtifactRepositoryInfo -from typing import Any, Dict +class WorkerBootstrapStack(CfnStack): # pragma: no cover + worker_role: Role + worker_bootstrap_role: Role + worker_instance_profile: InstanceProfile + session_role: Role + job_attachments_bucket: Bucket + job_attachments_bucket_policy: CfnResource + bootstrap_bucket: Bucket + bootstrap_bucket_policy: CfnResource -# IAM Roles -def generate_boostrap_worker_role_cfn_template() -> Dict[str, Any]: - cfn_template = { - "Type": "AWS::IAM::Role", - "Properties": { - "RoleName": DEADLINE_WORKER_BOOTSTRAP_ROLE, - "Description": DEADLINE_WORKER_BOOTSTRAP_ROLE, - "AssumeRolePolicyDocument": { + def __init__( + self, + *, + name: str, + account: str, + credential_vending_service_principal: str, + codeartifact: CodeArtifactRepositoryInfo, + description: str | None = None, + service_model_s3_object_arn: str | None = None, + ) -> None: + super().__init__( + name=name, + description=description, + capabilities=["CAPABILITY_NAMED_IAM"], + ) + + self.bootstrap_bucket = Bucket( + self, + "BootstrapBucket", + bucket_name=f"deadline-scaffolding-worker-bootstrap-{account}", + update_replace_policy="Delete", + deletion_policy="Delete", + ) + + self.bootstrap_bucket_policy = BucketPolicy( + self, + "BootstrapBucketPolicy", + bucket=self.bootstrap_bucket, + policy_document={ "Version": "2012-10-17", "Statement": [ { - "Effect": "Allow", - "Principal": {"Service": "ec2.amazonaws.com"}, - "Action": "sts:AssumeRole", - } - ], - }, - "Policies": [ - { - "PolicyName": f"{DEADLINE_WORKER_BOOTSTRAP_ROLE}Policy", - "PolicyDocument": { - 
"Version": "2012-10-17", - "Statement": [ - # Allows the worker to bootstrap itself and grab the correct credentials - { - "Effect": "Allow", - "Action": [ - "deadline:CreateWorker", - "deadline:GetWorkerIamCredentials", - "deadline:AssumeFleetRoleForWorker", - ], - "Resource": "*", - }, - # Allows the worker to download service model - { - "Action": ["s3:GetObject", "s3:HeadObject"], - "Resource": [ - f"arn:aws:s3:::{DEADLINE_SERVICE_MODEL_BUCKET}/service-2.json" - ], - "Effect": "Allow", - }, - # Allows access to code artifact - { - "Action": ["codeartifact:GetAuthorizationToken"], - "Resource": [ - f"arn:aws:codeartifact:us-west-2:{CODEARTIFACT_ACCOUNT_ID}:domain/{CODEARTIFACT_DOMAIN}" - ], - "Effect": "Allow", - }, - { - "Action": ["sts:GetServiceBearerToken"], - "Resource": "*", - "Effect": "Allow", - }, - { - "Action": [ - "codeartifact:ReadFromRepository", - "codeartifact:GetRepositoryEndpoint", - ], - "Resource": [ - f"arn:aws:codeartifact:us-west-2:{CODEARTIFACT_ACCOUNT_ID}:repository/{CODEARTIFACT_DOMAIN}/{CODEARTIFACT_REPOSITORY}" - ], - "Effect": "Allow", - }, + "Action": "s3:*", + "Effect": "Deny", + "Principal": "*", + "Resource": [ + self.bootstrap_bucket.arn, + self.bootstrap_bucket.arn_for_objects(), ], + "Condition": {"Bool": {"aws:SecureTransport": "false"}}, }, - }, - ], - "ManagedPolicyArns": ["arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore"], - }, - } - return cfn_template - - -def generate_boostrap_instance_profile_cfn_template() -> Dict[str, Any]: - cfn_template = { - "Type": "AWS::IAM::InstanceProfile", - "Properties": { - "InstanceProfileName": DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME, - "Roles": [ - {"Ref": DEADLINE_WORKER_BOOTSTRAP_ROLE}, - ], - }, - } - return cfn_template - + ], + }, + ) -def generate_worker_role_cfn_template() -> Dict[str, Any]: - """This role matches the worker role of the closed-beta console""" - cfn_template = { - "Type": "AWS::IAM::Role", - "Properties": { - "RoleName": DEADLINE_WORKER_ROLE, - "Description": DEADLINE_WORKER_ROLE, - "AssumeRolePolicyDocument": { + self.worker_role = Role( + self, + "WorkerRole", + role_name="DeadlineScaffoldingWorkerRole", + assume_role_policy_document={ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Principal": {"Service": CREDENTIAL_VENDING_PRINCIPAL}, + "Principal": {"Service": credential_vending_service_principal}, "Action": "sts:AssumeRole", - } + }, ], }, - "Policies": [ + policies=[ { - "PolicyName": f"{DEADLINE_WORKER_ROLE}Policy", + "PolicyName": "DeadlineScaffoldingWorkerRolePolicy", "PolicyDocument": { "Version": "2012-10-17", "Statement": [ @@ -173,98 +133,158 @@ def generate_worker_role_cfn_template() -> Dict[str, Any]: }, }, ], - }, - } - return cfn_template + ) - -def generate_queue_session_role() -> Dict[str, Any]: - cfn_template = { - "Type": "AWS::IAM::Role", - "Properties": { - "RoleName": DEADLINE_QUEUE_SESSION_ROLE, - "Description": DEADLINE_QUEUE_SESSION_ROLE, - "AssumeRolePolicyDocument": { + self.worker_bootstrap_role = Role( + self, + "WorkerBootstrapRole", + role_name="DeadlineScaffoldingWorkerBootstrapRole", + assume_role_policy_document={ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Principal": {"Service": CREDENTIAL_VENDING_PRINCIPAL}, + "Principal": {"Service": "ec2.amazonaws.com"}, "Action": "sts:AssumeRole", - } + }, ], }, - "Policies": [ + managed_policy_arns=[ + "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore", + ], + policies=[ { - "PolicyName": f"{DEADLINE_QUEUE_SESSION_ROLE}Policy", + "PolicyName": 
"DeadlineScaffoldingWorkerBootstrapRolePolicy", "PolicyDocument": { "Version": "2012-10-17", "Statement": [ + # Allows the worker to bootstrap itself and grab the correct credentials + { + "Effect": "Allow", + "Action": [ + "deadline:CreateWorker", + "deadline:GetWorkerIamCredentials", + "deadline:AssumeFleetRoleForWorker", + ], + "Resource": "*", + }, + # Allow access to bootstrap bucket { "Effect": "Allow", "Action": [ + "s3:HeadObject", "s3:GetObject", - "s3:PutObject", - "s3:ListBucket", - "s3:GetBucketLocation", ], - "Resource": [ - f"arn:aws:s3:::{JOB_ATTACHMENTS_BUCKET_NAME}" - f"arn:aws:s3:::{JOB_ATTACHMENTS_BUCKET_NAME}/*" + "Resource": [self.bootstrap_bucket.arn_for_objects()], + }, + # Allows the worker to download service model + *( + [ + { + "Action": ["s3:GetObject", "s3:HeadObject"], + "Resource": [service_model_s3_object_arn], + "Effect": "Allow", + } + ] + if service_model_s3_object_arn + else [] + ), + # Allows access to code artifact + { + "Action": ["codeartifact:GetAuthorizationToken"], + "Resource": [codeartifact.domain_arn], + "Effect": "Allow", + }, + { + "Action": ["sts:GetServiceBearerToken"], + "Resource": "*", + "Effect": "Allow", + }, + { + "Action": [ + "codeartifact:ReadFromRepository", + "codeartifact:GetRepositoryEndpoint", ], - } + "Resource": [codeartifact.repository_arn], + "Effect": "Allow", + }, ], }, - } + }, ], - }, - } - - return cfn_template + ) + self.worker_instance_profile = InstanceProfile( + self, + "WorkerBootstrapInstanceProfile", + instance_profile_name="DeadlineScaffoldingWorkerBootstrapInstanceProfile", + roles=[self.worker_bootstrap_role], + ) -# Job Attachments Bucket -def generate_job_attachments_bucket() -> Dict[str, Any]: - cfn_template = { - "Type": "AWS::S3::Bucket", - "Properties": { - "BucketName": JOB_ATTACHMENTS_BUCKET_NAME, - "BucketEncryption": { - "ServerSideEncryptionConfiguration": [ - {"ServerSideEncryptionByDefault": {"SSEAlgorithm": "AES256"}} - ] - }, - "PublicAccessBlockConfiguration": { - "BlockPublicAcls": True, - "BlockPublicPolicy": True, - "IgnorePublicAcls": True, - "RestrictPublicBuckets": True, - }, - }, - "UpdateReplacePolicy": "Delete", - "DeletionPolicy": "Delete", - } + self.job_attachments_bucket = Bucket( + self, + "JobAttachmentsBucket", + bucket_name=f"deadline-scaffolding-worker-job-attachments-{account}", + update_replace_policy="Delete", + deletion_policy="Delete", + ) - return cfn_template - - -def generate_job_attachments_bucket_policy() -> Dict[str, Any]: - cfn_template = { - "Type": "AWS::S3::BucketPolicy", - "Properties": { - "Bucket": {"Ref": JOB_ATTACHMENTS_BUCKET_RESOURCE}, - "PolicyDocument": { + self.job_attachments_bucket_policy = BucketPolicy( + self, + "JobAttachmentsBucketPolicy", + bucket=self.job_attachments_bucket, + policy_document={ + "Version": "2012-10-17", "Statement": [ { "Action": "s3:*", "Effect": "Deny", "Principal": "*", - "Resource": f"arn:aws:s3:::{JOB_ATTACHMENTS_BUCKET_NAME}/*", + "Resource": [ + self.job_attachments_bucket.arn, + self.job_attachments_bucket.arn_for_objects(), + ], "Condition": {"Bool": {"aws:SecureTransport": "false"}}, - } - ] + }, + ], }, - }, - } + ) - return cfn_template + self.session_role = Role( + self, + "SessionRole", + role_name="DeadlineScaffoldingWorkerSessionRole", + assume_role_policy_document={ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": credential_vending_service_principal}, + "Action": "sts:AssumeRole", + }, + ], + }, + policies=[ + { + "PolicyName": 
"DeadlineScaffoldingWorkerSessionRolePolicy", + "PolicyDocument": { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:ListBucket", + "s3:GetBucketLocation", + ], + "Resource": [ + self.job_attachments_bucket.arn, + self.job_attachments_bucket.arn_for_objects(), + ], + }, + ], + }, + }, + ], + ) diff --git a/src/deadline_test_scaffolding/constants.py b/src/deadline_test_scaffolding/constants.py deleted file mode 100644 index d3186fd..0000000 --- a/src/deadline_test_scaffolding/constants.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -import os - -STAGE = os.environ.get("STAGE", "Prod") - -BOOTSTRAP_CLOUDFORMATION_STACK_NAME = f"TestScaffoldingStack{STAGE}" - -# Role Names -DEADLINE_WORKER_BOOTSTRAP_ROLE = f"DeadlineWorkerBootstrapRole{STAGE}" -DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME = f"DeadlineWorkerBootstrapInstanceProfile{STAGE}" -DEADLINE_WORKER_ROLE = f"DeadlineWorkerTestRole{STAGE}" -DEADLINE_QUEUE_SESSION_ROLE = f"DeadlineScaffoldingQueueSessionRole{STAGE}" - -# Job Attachments -JOB_ATTACHMENTS_BUCKET_RESOURCE = "ScaffoldingJobAttachmentsBucket" -JOB_ATTACHMENTS_BUCKET_NAME = os.environ.get( - "JOB_ATTACHMENTS_BUCKET_NAME", "scaffolding-job-attachments-bucket" -) -JOB_ATTACHMENTS_BUCKET_POLICY_RESOURCE = f"JobAttachmentsPolicy{STAGE}" -JOB_ATTACHMENTS_ROOT_PREFIX = "root" - -# Worker Agent Configurations -DEFAULT_CMF_CONFIG = { - "customerManaged": { - "autoScalingConfiguration": { - "mode": "NO_SCALING", - "maxFleetSize": 1, - }, - "workerRequirements": { - "vCpuCount": {"min": 1}, - "memoryMiB": {"min": 1024}, - "osFamily": "linux", - "cpuArchitectureType": "x86_64", - }, - } -} - -# Service Principals -CREDENTIAL_VENDING_PRINCIPAL = os.environ.get( - "CREDENTIAL_VENDING_PRINCIPAL", "credential-vending.deadline-closed-beta.amazonaws.com" -) - -# Temporary constants -DEADLINE_SERVICE_MODEL_BUCKET = os.environ.get("DEADLINE_SERVICE_MODEL_BUCKET", "") -CODEARTIFACT_DOMAIN = os.environ.get("CODEARTIFACT_DOMAIN", "") -CODEARTIFACT_ACCOUNT_ID = os.environ.get("CODEARTIFACT_ACCOUNT_ID", "") -CODEARTIFACT_REPOSITORY = os.environ.get("CODEARTIFACT_REPOSITORY", "") diff --git a/src/deadline_test_scaffolding/containers/worker/Dockerfile b/src/deadline_test_scaffolding/containers/worker/Dockerfile new file mode 100644 index 0000000..3a5ae52 --- /dev/null +++ b/src/deadline_test_scaffolding/containers/worker/Dockerfile @@ -0,0 +1,102 @@ +# syntax=docker/dockerfile:1 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +FROM public.ecr.aws/docker/library/python:3.9-buster + +ARG AGENT_USER=agentuser +ARG JOB_USER=jobuser +ARG SHARED_GROUP=sharedgroup +ARG CONFIGURE_WORKER_AGENT_CMD +ARG FILE_MAPPINGS + +ARG AWS_ACCESS_KEY_ID +ARG AWS_SECRET_ACCESS_KEY +ARG AWS_SESSION_TOKEN +ARG AWS_DEFAULT_REGION + +RUN < /etc/sudoers.d/$AGENT_USER +END_RUN + +COPY --chown=$AGENT_USER:$AGENT_USER file_mappings /file_mappings + +COPY --chown=root:root <<-EOF /entrypoint.sh +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+
+set -euxo pipefail
+
+# Copy over file mappings
+file_mappings='$FILE_MAPPINGS'
+file_mappings=\$(echo "\$file_mappings" | jq -cr 'to_entries[] | "\\(.key):\\(.value)"')
+for mapping in \$file_mappings; do
+    IFS=: read -r src dst <<< "\$mapping"
+    mkdir -p "\$(dirname \$dst)"
+    mv "\$src" "\$dst"
+done
+
+# Configure the Worker agent
+$CONFIGURE_WORKER_AGENT_CMD
+
+sudo --preserve-env -H -u $AGENT_USER "\$@"
+EOF
+
+COPY --chown=$AGENT_USER:$AGENT_USER --chmod=750 <<-EOF /home/$AGENT_USER/run_agent.sh
+#!/bin/bash
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+set -euxo pipefail
+
+deadline-worker-agent --allow-instance-profile --no-shutdown
+EOF
+
+ENV AGENT_USER $AGENT_USER
+USER root
+WORKDIR /
+ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
+CMD ["/bin/bash", "-c", "/home/$AGENT_USER/run_agent.sh"]
diff --git a/src/deadline_test_scaffolding/containers/worker/run_container.sh b/src/deadline_test_scaffolding/containers/worker/run_container.sh
new file mode 100755
index 0000000..ab662ce
--- /dev/null
+++ b/src/deadline_test_scaffolding/containers/worker/run_container.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+set -eu
+
+tmp_env_file=$(mktemp -p $(pwd))
+for var in AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY AWS_SESSION_TOKEN AWS_DEFAULT_REGION FARM_ID FLEET_ID
+do
+    if test "${!var:-}" == "";
+    then
+        echo "ERROR: Environment variable $var must be set"
+        exit 1
+    fi
+    echo -n "$var=" >> $tmp_env_file
+    printenv $var >> $tmp_env_file
+done
+
+# Expand with a default so "set -u" does not abort when PIP_INDEX_URL is unset
+if test "${PIP_INDEX_URL:-}" != ""; then
+    echo "PIP_INDEX_URL=$PIP_INDEX_URL" >> $tmp_env_file
+fi
+
+FILE_MAPPINGS=${FILE_MAPPINGS:-}
+if [ -z "$FILE_MAPPINGS" ]; then
+    # Put a dummy file so Dockerfile COPY command still has something to copy
+    mkdir -p "$(pwd)/file_mappings"
+    dummy_file_path="$(pwd)/file_mappings/dummy_file"
+    touch "$dummy_file_path"
+    FILE_MAPPINGS="{\"$dummy_file_path\": \"$dummy_file_path\"}"
+fi
+
+# Use Docker BuildKit to use new features like heredoc support in Dockerfile
+container_image_tag=agent_integ
+DOCKER_BUILDKIT=1 docker build .
-q -t $container_image_tag \ + --build-arg AWS_ACCESS_KEY_ID \ + --build-arg AWS_SECRET_ACCESS_KEY \ + --build-arg AWS_SESSION_TOKEN \ + --build-arg AWS_DEFAULT_REGION \ + --build-arg AGENT_USER \ + --build-arg JOB_USER \ + --build-arg SHARED_GROUP \ + --build-arg CONFIGURE_WORKER_AGENT_CMD \ + --build-arg FILE_MAPPINGS + +docker run \ + --rm \ + --detach \ + --name integ_worker_agent \ + --env-file $tmp_env_file \ + -h worker-integ.environment.internal \ + --cidfile $(pwd)/.container_id \ + $container_image_tag:latest + +rm -f $tmp_env_file \ No newline at end of file diff --git a/src/deadline_test_scaffolding/deadline/__init__.py b/src/deadline_test_scaffolding/deadline/__init__.py new file mode 100644 index 0000000..9352296 --- /dev/null +++ b/src/deadline_test_scaffolding/deadline/__init__.py @@ -0,0 +1,33 @@ +from .resources import ( + Farm, + Fleet, + Job, + Queue, + QueueFleetAssociation, + TaskStatus, +) +from .client import DeadlineClient +from .worker import ( + CommandResult, + DeadlineWorker, + DeadlineWorkerConfiguration, + DockerContainerWorker, + EC2InstanceWorker, + PipInstall, +) + +__all__ = [ + "CommandResult", + "DeadlineClient", + "DeadlineWorker", + "DeadlineWorkerConfiguration", + "DockerContainerWorker", + "EC2InstanceWorker", + "Farm", + "Fleet", + "Job", + "PipInstall", + "Queue", + "QueueFleetAssociation", + "TaskStatus", +] diff --git a/src/deadline_test_scaffolding/deadline/client.py b/src/deadline_test_scaffolding/deadline/client.py new file mode 100644 index 0000000..d1d4834 --- /dev/null +++ b/src/deadline_test_scaffolding/deadline/client.py @@ -0,0 +1,165 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import logging +import os +from typing import Any, Optional + +from botocore.loaders import Loader +from botocore.model import ServiceModel, OperationModel + +LOG = logging.getLogger(__name__) + + +class DeadlineClient: + """ + A shim layer for boto Deadline client. This class will check if a method exists on the real + boto3 Deadline client and call it if it exists. If it doesn't exist, an AttributeError will be raised. 
+ """ + + _real_client: Any + + def __init__(self, real_client: Any) -> None: + self._real_client = real_client + + def create_farm(self, *args, **kwargs) -> Any: + create_farm_input_members = self._get_deadline_api_input_shape("CreateFarm") + if "displayName" not in create_farm_input_members and "name" in create_farm_input_members: + kwargs["name"] = kwargs.pop("displayName") + return self._real_client.create_farm(*args, **kwargs) + + def create_fleet(self, *args, **kwargs) -> Any: + create_fleet_input_members = self._get_deadline_api_input_shape("CreateFleet") + if "displayName" not in create_fleet_input_members and "name" in create_fleet_input_members: + kwargs["name"] = kwargs.pop("displayName") + if ( + "roleArn" not in create_fleet_input_members + and "workeRoleArn" in create_fleet_input_members + ): + kwargs["workerRoleArn"] = kwargs.pop("roleArn") + return self._real_client.create_fleet(*args, **kwargs) + + def get_fleet(self, *args, **kwargs) -> Any: + response = self._real_client.get_fleet(*args, **kwargs) + if "name" in response and "displayName" not in response: + response["displayName"] = response["name"] + del response["name"] + if "state" in response and "status" not in response: + response["status"] = response["state"] + del response["state"] + if "type" in response: + del response["type"] + return response + + def get_queue_fleet_association(self, *args, **kwargs) -> Any: + response = self._real_client.get_queue_fleet_association(*args, **kwargs) + if "state" in response and "status" not in response: + response["status"] = response["state"] + del response["state"] + return response + + def create_queue(self, *args, **kwargs) -> Any: + create_queue_input_members = self._get_deadline_api_input_shape("CreateQueue") + if "displayName" not in create_queue_input_members and "name" in create_queue_input_members: + kwargs["name"] = kwargs.pop("displayName") + return self._real_client.create_queue(*args, **kwargs) + + def create_queue_fleet_association(self, *args, **kwargs) -> Any: + create_queue_fleet_association_method_name: Optional[str] + create_queue_fleet_association_method: Optional[str] + + for create_queue_fleet_association_method_name in ( + "put_queue_fleet_association", + "create_queue_fleet_association", + ): + create_queue_fleet_association_method = getattr( + self._real_client, create_queue_fleet_association_method_name, None + ) + if create_queue_fleet_association_method: + break + else: + create_queue_fleet_association_method = None + + # mypy complains about they kwargs type + return create_queue_fleet_association_method(*args, **kwargs) # type: ignore + + def create_job(self, *args, **kwargs) -> Any: + create_job_input_members = self._get_deadline_api_input_shape("CreateJob") + # revert to old parameter names if old service model is used + if "maxRetriesPerTask" in kwargs: + if "maxErrorsPerTask" in create_job_input_members: + kwargs["maxErrorsPerTask"] = kwargs.pop("maxRetriesPerTask") + if "template" in kwargs: + if "jobTemplate" in create_job_input_members: + kwargs["jobTemplate"] = kwargs.pop("template") + kwargs["jobTemplateType"] = kwargs.pop("templateType") + if "parameters" in kwargs: + kwargs["jobParameters"] = kwargs.pop("parameters") + if "targetTaskRunStatus" in kwargs: + if "initialState" in create_job_input_members: + kwargs["initialState"] = kwargs.pop("targetTaskRunStatus") + if "priority" not in kwargs: + kwargs["priority"] = 50 + return self._real_client.create_job(*args, **kwargs) + + def update_queue_fleet_association(self, *args, **kwargs) 
-> Any: + update_queue_fleet_association_method_name: Optional[str] + update_queue_fleet_association_method: Optional[str] + + for update_queue_fleet_association_method_name in ( + "update_queue_fleet_association", + "update_queue_fleet_association_state", + ): + update_queue_fleet_association_method = getattr( + self._real_client, update_queue_fleet_association_method_name, None + ) + if update_queue_fleet_association_method: + break + else: + update_queue_fleet_association_method = None + + if update_queue_fleet_association_method_name == "update_queue_fleet_association": + # mypy complains about they kwargs type + return update_queue_fleet_association_method(*args, **kwargs) # type: ignore + + if update_queue_fleet_association_method_name == "update_queue_fleet_association_state": + kwargs["state"] = kwargs.pop("status") + # mypy complains about they kwargs type + return update_queue_fleet_association_method(*args, **kwargs) # type: ignore + + def _get_deadline_api_input_shape(self, api_name: str) -> dict[str, Any]: + """ + Given a string name of an API e.g. CreateJob, returns the shape of the + inputs to that API. + """ + api_model = self._get_deadline_api_model(api_name) + if api_model: + return api_model.input_shape.members + return {} + + def _get_deadline_api_model(self, api_name: str) -> Optional[OperationModel]: + """ + Given a string name of an API e.g. CreateJob, returns the OperationModel + for that API from the service model. + """ + data_model_path = os.getenv("AWS_DATA_PATH") + loader = Loader(extra_search_paths=[data_model_path] if data_model_path is not None else []) + deadline_service_description = loader.load_service_model("deadline", "service-2") + deadline_service_model = ServiceModel(deadline_service_description, service_name="deadline") + return OperationModel( + deadline_service_description["operations"][api_name], deadline_service_model + ) + + def __getattr__(self, __name: str) -> Any: + """ + Respond to unknown method calls by calling the underlying _real_client + If the underlying _real_client does not have a given method, an AttributeError + will be raised. + Note that __getattr__ is only called if the attribute cannot otherwise be found, + so if this class alread has the called method defined, __getattr__ will not be called. + This is in opposition to __getattribute__ which is called by default. + """ + + def method(*args, **kwargs): + return getattr(self._real_client, __name)(*args, **kwargs) + + return method diff --git a/src/deadline_test_scaffolding/deadline/resources.py b/src/deadline_test_scaffolding/deadline/resources.py new file mode 100644 index 0000000..4e97973 --- /dev/null +++ b/src/deadline_test_scaffolding/deadline/resources.py @@ -0,0 +1,555 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
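
The shim above lets tests written against the latest Deadline API names run against older service models. A minimal usage sketch (assuming a "deadline" service model is resolvable by botocore, e.g. side-loaded via AWS_DATA_PATH as in _get_deadline_api_model; the farm name is invented):

    import boto3

    from deadline_test_scaffolding import DeadlineClient

    client = DeadlineClient(boto3.client("deadline"))

    # create_farm is shimmed: "displayName" is renamed to "name" when the
    # loaded service model only knows the old parameter.
    response = client.create_farm(displayName="scaffolding-demo-farm")

    # Anything without a shim falls through __getattr__ to the real client.
    client.delete_farm(farmId=response["farmId"])
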
+from __future__ import annotations + +import datetime +import json +import logging +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Callable, Literal + +from .client import DeadlineClient +from ..models import JobAttachmentSettings +from ..util import call_api, clean_kwargs, wait_for + +LOG = logging.getLogger(__name__) + + +@dataclass +class Farm: + id: str + + @staticmethod + def create( + *, + client: DeadlineClient, + display_name: str, + ) -> Farm: + response = call_api( + description=f"Create farm {display_name}", + fn=lambda: client.create_farm( + displayName=display_name, + ), + ) + farm_id = response["farmId"] + LOG.info(f"Created farm: {farm_id}") + return Farm(id=farm_id) + + def delete(self, *, client: DeadlineClient) -> None: + call_api( + description=f"Delete farm {self.id}", + fn=lambda: client.delete_farm(farmId=self.id), + ) + + +@dataclass +class Queue: + id: str + farm: Farm + + @staticmethod + def create( + *, + client: DeadlineClient, + display_name: str, + farm: Farm, + role_arn: str | None = None, + job_attachments: JobAttachmentSettings | None = None, + ) -> Queue: + kwargs = clean_kwargs( + { + "displayName": display_name, + "farmId": farm.id, + "roleArn": role_arn, + "jobAttachmentSettings": ( + job_attachments.as_queue_settings() if job_attachments else None + ), + } + ) + + response = call_api( + description=f"Create queue {display_name} in farm {farm.id}", + fn=lambda: client.create_queue(**kwargs), + ) + + queue_id = response["queueId"] + LOG.info(f"Created queue: {queue_id}") + return Queue( + id=queue_id, + farm=farm, + ) + + def delete(self, *, client: DeadlineClient) -> None: + call_api( + description=f"Delete queue {self.id}", + fn=lambda: client.delete_queue(queueId=self.id, farmId=self.farm.id), + ) + + +@dataclass +class Fleet: + id: str + farm: Farm + + @staticmethod + def create( + *, + client: DeadlineClient, + display_name: str, + farm: Farm, + configuration: dict, + role_arn: str | None = None, + ) -> Fleet: + kwargs = clean_kwargs( + { + "farmId": farm.id, + "displayName": display_name, + "roleArn": role_arn, + "configuration": configuration, + } + ) + response = call_api( + fn=lambda: client.create_fleet(**kwargs), + description=f"Create fleet {display_name} in farm {farm.id}", + ) + fleet_id = response["fleetId"] + LOG.info(f"Created fleet: {fleet_id}") + fleet = Fleet( + id=fleet_id, + farm=farm, + ) + + fleet.wait_for_desired_status( + client=client, + desired_status="ACTIVE", + allowed_statuses=set(["CREATE_IN_PROGRESS"]), + ) + + return fleet + + def delete(self, *, client: DeadlineClient) -> None: + call_api( + description=f"Delete fleet {self.id}", + fn=lambda: client.delete_fleet( + farmId=self.farm.id, + fleetId=self.id, + ), + ) + + def wait_for_desired_status( + self, + *, + client: DeadlineClient, + desired_status: str, + allowed_statuses: set[str] = set(), + interval_s: int = 10, + max_retries: int = 6, + ) -> None: + valid_statuses = set([desired_status]).union(allowed_statuses) + + # Temporary until we have waiters + def is_fleet_desired_status() -> bool: + response = call_api( + description=f"Get fleet {self.id}", + fn=lambda: client.get_fleet(fleetId=self.id, farmId=self.farm.id), + ) + fleet_status = response["status"] + + if fleet_status not in valid_statuses: + raise ValueError( + f"fleet entered a nonvalid status ({fleet_status}) while " + f"waiting for the desired status: {desired_status}" + ) + + return fleet_status == desired_status + + wait_for( + description=f"fleet {self.id} 
to reach desired status {desired_status}", + predicate=is_fleet_desired_status, + interval_s=interval_s, + max_retries=max_retries, + ) + + +@dataclass +class QueueFleetAssociation: + farm: Farm + queue: Queue + fleet: Fleet + + @staticmethod + def create( + *, + client: DeadlineClient, + farm: Farm, + queue: Queue, + fleet: Fleet, + ) -> QueueFleetAssociation: + call_api( + description=f"Create queue-fleet association for queue {queue.id} and fleet {fleet.id} in farm {farm.id}", + fn=lambda: client.create_queue_fleet_association( + farmId=farm.id, + queueId=queue.id, + fleetId=fleet.id, + ), + ) + return QueueFleetAssociation( + farm=farm, + queue=queue, + fleet=fleet, + ) + + def delete( + self, + *, + client: DeadlineClient, + stop_mode: Literal[ + "STOP_SCHEDULING_AND_CANCEL_TASKS", "STOP_SCHEDULING_AND_FINISH_TASKS" + ] = "STOP_SCHEDULING_AND_CANCEL_TASKS", + ) -> None: + self.stop(client=client, stop_mode=stop_mode) + call_api( + description=f"Delete queue-fleet association for queue {self.queue.id} and fleet {self.fleet.id} in farm {self.farm.id}", + fn=lambda: client.delete_queue_fleet_association( + farmId=self.farm.id, + queueId=self.queue.id, + fleetId=self.fleet.id, + ), + ) + + def stop( + self, + *, + client: DeadlineClient, + stop_mode: Literal[ + "STOP_SCHEDULING_AND_CANCEL_TASKS", "STOP_SCHEDULING_AND_FINISH_TASKS" + ] = "STOP_SCHEDULING_AND_CANCEL_TASKS", + interval_s: int = 10, + max_retries: int = 6, + ) -> None: + call_api( + description=f"Set queue-fleet association to STOPPING_SCHEDULING_AND_CANCELING_TASKS for queue {self.queue.id} and fleet {self.fleet.id}", + fn=lambda: client.update_queue_fleet_association( + farmId=self.farm.id, + queueId=self.queue.id, + fleetId=self.fleet.id, + status=stop_mode, + ), + ) + + # Temporary until we have waiters + valid_statuses = set(["STOPPED", stop_mode]) + + def is_qfa_in_desired_status() -> bool: + response = call_api( + description=f"Get queue-fleet association for queue {self.queue.id} and fleet {self.fleet.id}", + fn=lambda: client.get_queue_fleet_association( + farmId=self.farm.id, + queueId=self.queue.id, + fleetId=self.fleet.id, + ), + ) + + qfa_status = response["status"] + if qfa_status not in valid_statuses: + raise ValueError( + f"Association entered a nonvalid status ({qfa_status}) while " + "waiting for the desired status: STOPPED" + ) + + return qfa_status == "STOPPED" + + wait_for( + description="queue-fleet association to reach desired status STOPPED", + predicate=is_qfa_in_desired_status, + interval_s=interval_s, + max_retries=max_retries, + ) + + +class StrEnum(str, Enum): + pass + + +class TaskStatus(StrEnum): + UNKNOWN = "UNKNOWN" + PENDING = "PENDING" + READY = "READY" + RUNNING = "RUNNING" + ASSIGNED = "ASSIGNED" + SCHEDULED = "SCHEDULED" + INTERRUPTING = "INTERRUPTING" + SUSPENDED = "SUSPENDED" + CANCELED = "CANCELED" + FAILED = "FAILED" + SUCCEEDED = "SUCCEEDED" + + +COMPLETE_TASK_STATUSES = set( + ( + TaskStatus.CANCELED, + TaskStatus.FAILED, + TaskStatus.SUCCEEDED, + ) +) + + +@dataclass +class Job: + id: str + farm: Farm + queue: Queue + template: dict + + name: str + lifecycle_status: str + lifecycle_status_message: str + priority: int + created_at: datetime.datetime + created_by: str + + updated_at: datetime.datetime | None = None + updated_by: str | None = None + started_at: datetime.datetime | None = None + ended_at: datetime.datetime | None = None + task_run_status: TaskStatus | None = None + target_task_run_status: TaskStatus | None = None + task_run_status_counts: dict[TaskStatus, int] | None 
= None + storage_profile_id: str | None = None + max_failed_tasks_count: int | None = None + max_retries_per_task: int | None = None + parameters: dict | None = None + attachments: dict | None = None + description: str | None = None + + @staticmethod + def submit( + *, + client: DeadlineClient, + farm: Farm, + queue: Queue, + template: dict, + priority: int, + parameters: dict | None = None, + attachments: dict | None = None, + target_task_run_status: str | None = None, + max_failed_tasks_count: int | None = None, + max_retries_per_task: int | None = None, + ) -> Job: + kwargs = clean_kwargs( + { + "farmId": farm.id, + "queueId": queue.id, + "template": json.dumps(template), + "templateType": "JSON", + "priority": priority, + "parameters": parameters, + "attachments": attachments, + "targetTaskRunStatus": target_task_run_status, + "maxFailedTasksCount": max_failed_tasks_count, + "maxRetriesPerTask": max_retries_per_task, + } + ) + create_job_response = call_api( + description=f"Create job in farm {farm.id} and queue {queue.id}", + fn=lambda: client.create_job(**kwargs), + ) + job_id = create_job_response["jobId"] + LOG.info(f"Created job: {job_id}") + + job_details = Job.get_job_details( + client=client, + farm=farm, + queue=queue, + job_id=job_id, + ) + + return Job( + farm=farm, + queue=queue, + template=template, + **job_details, + ) + + @staticmethod + def get_job_details( + *, + client: DeadlineClient, + farm: Farm, + queue: Queue, + job_id: str, + ) -> dict[str, Any]: + """ + Calls GetJob API and returns the parsed response, which can be used as + keyword arguments to create/update this class. + """ + response = call_api( + description=f"Fetching job details for job {job_id}", + fn=lambda: client.get_job( + farmId=farm.id, + queueId=queue.id, + jobId=job_id, + ), + ) + + def get_optional_field( + name: str, + *, + default: Any = None, + transform: Callable[[Any], Any] | None = None, + ): + if name not in response: + return default + return transform(response[name]) if transform else response[name] + + return { + "id": response["jobId"], + "name": response["name"], + "lifecycle_status": response["lifecycleStatus"], + "lifecycle_status_message": response["lifecycleStatusMessage"], + "priority": response["priority"], + "created_at": response["createdAt"], + "created_by": response["createdBy"], + "updated_at": get_optional_field("updatedAt"), + "updated_by": get_optional_field("updatedBy"), + "started_at": get_optional_field("startedAt"), + "ended_at": get_optional_field("endedAt"), + "task_run_status": get_optional_field( + "taskRunStatus", + transform=lambda trs: TaskStatus[trs], + ), + "target_task_run_status": get_optional_field( + "targetTaskRunStatus", + transform=lambda trs: TaskStatus[trs], + ), + "task_run_status_counts": get_optional_field( + "taskRunStatusCounts", + transform=lambda trsc: {TaskStatus[k]: v for k, v in trsc.items()}, + ), + "storage_profile_id": get_optional_field("storageProfileId"), + "max_failed_tasks_count": get_optional_field("maxFailedTasksCount"), + "max_retries_per_task": get_optional_field("maxRetriesPerTask"), + "parameters": get_optional_field("parameters"), + "attachments": get_optional_field("attachments"), + "description": get_optional_field("description"), + } + + def refresh_job_info(self, *, client: DeadlineClient) -> None: + """ + Calls GetJob API to refresh job information. The result is used to update the fields + of this class. 
+ """ + kwargs = Job.get_job_details( + client=client, + farm=self.farm, + queue=self.queue, + job_id=self.id, + ) + all_field_names = set([f.name for f in fields(self)]) + assert all(k in all_field_names for k in kwargs) + for k, v in kwargs.items(): + object.__setattr__(self, k, v) + + def update( + self, + *, + client: DeadlineClient, + priority: int | None = None, + target_task_run_status: str | None = None, + max_failed_tasks_count: int | None = None, + max_retries_per_task: int | None = None, + ) -> None: + kwargs = clean_kwargs( + { + "priority": priority, + "targetTaskRunStatus": target_task_run_status, + "maxFailedTasksCount": max_failed_tasks_count, + "maxRetriesPerTask": max_retries_per_task, + } + ) + call_api( + description=f"Update job in farm {self.farm.id} and queue {self.queue.id} with kwargs {kwargs}", + fn=lambda: client.update_job( + farmId=self.farm.id, + queueId=self.queue.id, + jobId=self.id, + **kwargs, + ), + ) + + def wait_until_complete( + self, + *, + client: DeadlineClient, + wait_interval_sec: int = 10, + max_retries: int | None = None, + ) -> None: + """ + Waits until the job is complete. + This method will refresh the job info until the job is complete or the operation times out. + + Args: + wait_interval_sec (int, optional): Interval between waits in seconds. Defaults to 5. + max_retries (int, optional): Maximum retry count. Defaults to None. + """ + + def _is_job_complete(): + self.refresh_job_info(client=client) + if not self.complete: + LOG.info(f"Job {self.id} not complete") + return self.complete + + wait_for( + description=f"job {self.id} to complete", + predicate=_is_job_complete, + interval_s=wait_interval_sec, + max_retries=max_retries, + ) + + @property + def complete(self) -> bool: # pragma: no cover + return self.task_run_status in COMPLETE_TASK_STATUSES + + def __str__(self) -> str: # pragma: no cover + if self.task_run_status_counts: + task_run_status_counts = "\n".join( + [ + f"\t{k}: {v}" + for k, v in sorted( + filter(lambda i: i[1] > 0, self.task_run_status_counts.items()), + key=lambda i: i[1], + reverse=True, + ) + ] + ) + else: + task_run_status_counts = str(self.task_run_status_counts) + + return "\n".join( + [ + "Job:", + f"id: {self.id}", + f"name: {self.name}", + f"description: {self.description}", + f"farm: {self.farm.id}", + f"queue: {self.queue.id}", + f"template: {json.dumps(self.template)}", + f"parameters: {self.parameters}", + f"attachments: {self.attachments}", + f"lifecycle_status: {self.lifecycle_status}", + f"lifecycle_status_message: {self.lifecycle_status_message}", + f"priority: {self.priority}", + f"target_task_run_status: {self.target_task_run_status}", + f"task_run_status: {self.task_run_status}", + f"task_run_status_counts:\n{task_run_status_counts}", + f"storage_profile_id: {self.storage_profile_id}", + f"max_failed_tasks_count: {self.max_failed_tasks_count}", + f"max_retries_per_task: {self.max_retries_per_task}", + f"created_at: {self.created_at}", + f"created_by: {self.created_by}", + f"updated_at: {self.updated_at}", + f"updated_by: {self.updated_by}", + f"started_at: {self.started_at}", + f"ended_at: {self.ended_at}", + ] + ) diff --git a/src/deadline_test_scaffolding/deadline_stub.py b/src/deadline_test_scaffolding/deadline/stubs.py similarity index 94% rename from src/deadline_test_scaffolding/deadline_stub.py rename to src/deadline_test_scaffolding/deadline/stubs.py index c75bb6b..83188c6 100644 --- a/src/deadline_test_scaffolding/deadline_stub.py +++ b/src/deadline_test_scaffolding/deadline/stubs.py @@ 
-3,7 +3,6 @@ import dataclasses from dataclasses import dataclass from typing import Optional -from deadline_test_scaffolding.constants import JOB_ATTACHMENTS_ROOT_PREFIX from botocore.exceptions import ClientError as OriginalClientError @@ -83,7 +82,7 @@ def get_queue(self, *, farmId: str, queueId: str) -> dict: "fleets": [], "jobAttachmentSettings": { "s3BucketName": self.job_attachments_bucket_name, - "rootPrefix": JOB_ATTACHMENTS_ROOT_PREFIX, + "rootPrefix": "root", }, } diff --git a/src/deadline_test_scaffolding/deadline/worker.py b/src/deadline_test_scaffolding/deadline/worker.py new file mode 100644 index 0000000..882dc83 --- /dev/null +++ b/src/deadline_test_scaffolding/deadline/worker.py @@ -0,0 +1,637 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from __future__ import annotations + +import abc +import botocore.client +import botocore.exceptions +import glob +import json +import logging +import os +import pathlib +import posixpath +import re +import shutil +import subprocess +import tempfile +import time +from dataclasses import dataclass, field, InitVar, replace +from typing import Any, ClassVar, Optional, cast + +from .client import DeadlineClient +from ..models import CodeArtifactRepositoryInfo, ServiceModel +from ..util import call_api, wait_for + +LOG = logging.getLogger(__name__) + +# Hardcoded to default posix path for worker.json file which has the worker ID in it +WORKER_JSON_PATH = "/var/lib/deadline/worker.json" +DOCKER_CONTEXT_DIR = os.path.join(os.path.dirname(__file__), "..", "containers", "worker") + + +def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str: # pragma: no cover + """Get the command to configure the Worker. This must be run as root.""" + cmds = [ + config.worker_agent_install.install_command, + # fmt: off + ( + "install-deadline-worker " + + "-y " + + f"--farm-id {config.farm_id} " + + f"--fleet-id {config.fleet_id} " + + f"--region {config.region} " + + f"--user {config.user} " + + f"--group {config.group} " + + f"{'--allow-shutdown ' if config.allow_shutdown else ''}" + + f"{'--no-install-service ' if config.no_install_service else ''}" + ), + # fmt: on + ] + + if config.service_model: + cmds.append( + f"runuser -l {config.user} -s /bin/bash -c '{config.service_model.install_command}'" + ) + + return " && ".join(cmds) + + +@dataclass(frozen=True) +class PipInstall: # pragma: no cover + requirement_specifiers: list[str] + """See https://peps.python.org/pep-0508/""" + upgrade_pip: bool = True + find_links: list[str] | None = None + no_deps: bool = False + force_reinstall: bool = False + codeartifact: CodeArtifactRepositoryInfo | None = None + + def __post_init__(self) -> None: + assert len( + self.requirement_specifiers + ), "At least one requirement specifier is required, but got 0" + + @property + def install_args(self) -> list[str]: + args = [] + if self.find_links: + args.append(f"--find-links={','.join(self.find_links)}") + if self.no_deps: + args.append("--no-deps") + if self.force_reinstall: + args.append("--force-reinstall") + return args + + @property + def install_command(self) -> str: + cmds = [] + + if self.codeartifact: + cmds.append( + "aws codeartifact login --tool pip " + + f"--domain {self.codeartifact.domain} " + + f"--domain-owner {self.codeartifact.domain_owner} " + + f"--repository {self.codeartifact.repository} " + ) + + if self.upgrade_pip: + cmds.append("pip install --upgrade pip") + + cmds.append( + " ".join( + [ + "pip", + "install", + *self.install_args, + 
*self.requirement_specifiers, + ] + ) + ) + + return " && ".join(cmds) + + +class DeadlineWorker(abc.ABC): + @abc.abstractmethod + def start(self) -> None: + pass + + @abc.abstractmethod + def stop(self) -> None: + pass + + @abc.abstractmethod + def send_command(self, command: str) -> CommandResult: + pass + + @abc.abstractproperty + def worker_id(self) -> str: + pass + + +@dataclass(frozen=True) +class CommandResult: # pragma: no cover + exit_code: int + stdout: str + stderr: Optional[str] = None + + def __str__(self) -> str: + return "\n".join( + [ + f"exit_code: {self.exit_code}", + "", + "================================", + "========= BEGIN stdout =========", + "================================", + "", + self.stdout, + "", + "==============================", + "========= END stdout =========", + "==============================", + "", + "================================", + "========= BEGIN stderr =========", + "================================", + "", + str(self.stderr), + "", + "==============================", + "========= END stderr =========", + "==============================", + ] + ) + + +@dataclass(frozen=True) +class DeadlineWorkerConfiguration: + farm_id: str + fleet_id: str + region: str + user: str + group: str + allow_shutdown: bool + worker_agent_install: PipInstall + no_install_service: bool = False + service_model: ServiceModel | None = None + file_mappings: list[tuple[str, str]] | None = None + """Mapping of files to copy from host environment to worker environment""" + + +@dataclass +class EC2InstanceWorker(DeadlineWorker): + AL2023_AMI_NAME: ClassVar[str] = "al2023-ami-kernel-6.1-x86_64" + + subnet_id: str + security_group_id: str + instance_profile_name: str + bootstrap_bucket_name: str + s3_client: botocore.client.BaseClient + ec2_client: botocore.client.BaseClient + ssm_client: botocore.client.BaseClient + deadline_client: DeadlineClient + configuration: DeadlineWorkerConfiguration + + instance_id: Optional[str] = field(init=False, default=None) + + override_ami_id: InitVar[Optional[str]] = None + """ + Option to override the AMI ID for the EC2 instance. The latest AL2023 is used by default. + Note that the scripting to configure the EC2 instance is only verified to work on AL2023. + """ + + def __post_init__(self, override_ami_id: Optional[str] = None): + if override_ami_id: + self._ami_id = override_ami_id + + def start(self) -> None: + s3_files = self._stage_s3_bucket() + self._launch_instance(s3_files=s3_files) + self._start_worker_agent() + + def stop(self) -> None: + LOG.info(f"Terminating EC2 instance {self.instance_id}") + self.ec2_client.terminate_instances(InstanceIds=[self.instance_id]) + self.instance_id = None + + def send_command(self, command: str) -> CommandResult: + """Send a command via SSM to a shell on a launched EC2 instance. Once the command has fully + finished the result of the invocation is returned. + """ + ssm_waiter = self.ssm_client.get_waiter("command_executed") + + # To successfully send an SSM Command to an instance the instance must: + # 1) Be in RUNNING state; + # 2) Have the AWS Systems Manager (SSM) Agent running; and + # 3) Have had enough time for the SSM Agent to connect to System's Manager + # + # If we send an SSM command then we will get an InvalidInstanceId error + # if the instance isn't in that state. 
+ NUM_RETRIES = 10 + SLEEP_INTERVAL_S = 5 + for i in range(0, NUM_RETRIES): + LOG.info(f"Sending SSM command to instance {self.instance_id}") + try: + send_command_response = self.ssm_client.send_command( + InstanceIds=[self.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [command]}, + ) + # Successfully sent. Bail out of the loop. + break + except botocore.exceptions.ClientError as error: + error_code = error.response["Error"]["Code"] + if error_code == "InvalidInstanceId" and i < NUM_RETRIES - 1: + LOG.warning( + f"Instance {self.instance_id} is not ready for SSM command (received InvalidInstanceId error). Retrying in {SLEEP_INTERVAL_S}s." + ) + time.sleep(SLEEP_INTERVAL_S) + continue + raise + + command_id = send_command_response["Command"]["CommandId"] + + LOG.info(f"Waiting for SSM command {command_id} to reach a terminal state") + try: + ssm_waiter.wait( + InstanceId=self.instance_id, + CommandId=command_id, + ) + except botocore.exceptions.WaiterError: # pragma: no cover + # Swallow exception, we're going to check the result anyway + pass + + ssm_command_result = self.ssm_client.get_command_invocation( + InstanceId=self.instance_id, + CommandId=command_id, + ) + result = CommandResult( + exit_code=ssm_command_result["ResponseCode"], + stdout=ssm_command_result["StandardOutputContent"], + stderr=ssm_command_result["StandardErrorContent"], + ) + if result.exit_code == -1: # pragma: no cover + # Response code of -1 in a terminal state means the command was not received by the node + LOG.error(f"Failed to send SSM command {command_id} to {self.instance_id}: {result}") + + LOG.info(f"SSM command {command_id} completed with exit code: {result.exit_code}") + return result + + def _stage_s3_bucket(self) -> list[tuple[str, str]] | None: + """Stages file_mappings to an S3 bucket and returns the mapping of S3 URI to dest path""" + if not self.configuration.file_mappings: + LOG.info("No file mappings to stage to S3") + return None + + s3_to_src_mapping: dict[str, str] = {} + s3_to_dst_mapping: dict[str, str] = {} + for src_glob, dst in self.configuration.file_mappings: + for src_file in glob.glob(src_glob): + s3_key = f"worker/{os.path.basename(src_file)}" + assert s3_key not in s3_to_src_mapping, ( + "Duplicate S3 keys generated for file mappings. All source files must have unique " + + f"filenames. 
Mapping: {self.configuration.file_mappings}" + ) + s3_to_src_mapping[s3_key] = src_file + s3_to_dst_mapping[f"s3://{self.bootstrap_bucket_name}/{s3_key}"] = dst + + for key, local_path in s3_to_src_mapping.items(): + LOG.info(f"Uploading file {local_path} to s3://{self.bootstrap_bucket_name}/{key}") + try: + # self.s3_client.upload_file(local_path, self.bootstrap_bucket_name, key) + with open(local_path, mode="rb") as f: + self.s3_client.put_object( + Bucket=self.bootstrap_bucket_name, + Key=key, + Body=f, + ) + except botocore.exceptions.ClientError as e: + LOG.exception( + f"Failed to upload file {local_path} to s3://{self.bootstrap_bucket_name}/{key}: {e}" + ) + raise + + return list(s3_to_dst_mapping.items()) + + def _launch_instance(self, *, s3_files: list[tuple[str, str]] | None = None) -> None: + assert ( + not self.instance_id + ), "Attempted to launch EC2 instance when one was already launched" + + copy_s3_command = "" + if s3_files: + copy_s3_command = " && ".join( + [ + f"aws s3 cp {s3_uri} {dst} && chown {self.configuration.user} {dst}" + for s3_uri, dst in s3_files + ] + ) + + LOG.info("Launching EC2 instance") + run_instance_response = self.ec2_client.run_instances( + MinCount=1, + MaxCount=1, + ImageId=self.ami_id, + InstanceType="t3.micro", + IamInstanceProfile={"Name": self.instance_profile_name}, + SubnetId=self.subnet_id, + SecurityGroupIds=[self.security_group_id], + MetadataOptions={"HttpTokens": "required", "HttpEndpoint": "enabled"}, + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": [ + { + "Key": "InstanceIdentification", + "Value": "DeadlineScaffoldingWorker", + } + ], + } + ], + UserData=f"""#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +set -x +groupadd --system {self.configuration.group} +useradd --create-home --system --shell=/bin/bash --groups={self.configuration.group} jobuser +useradd --create-home --system --shell=/bin/bash --groups={self.configuration.group} {self.configuration.user} +{copy_s3_command} + +runuser --login {self.configuration.user} --command 'python3 -m venv $HOME/.venv && echo ". $HOME/.venv/bin/activate" >> $HOME/.bashrc' +""", + ) + + self.instance_id = run_instance_response["Instances"][0]["InstanceId"] + LOG.info(f"Launched EC2 instance {self.instance_id}") + + LOG.info(f"Waiting for EC2 instance {self.instance_id} status to be OK") + instance_running_waiter = self.ec2_client.get_waiter("instance_status_ok") + instance_running_waiter.wait(InstanceIds=[self.instance_id]) + LOG.info(f"EC2 instance {self.instance_id} status is OK") + + def _start_worker_agent(self) -> None: # pragma: no cover + assert self.instance_id + + LOG.info(f"Sending SSM command to configure Worker agent on instance {self.instance_id}") + cmd_result = self.send_command( + f"cd /home/{self.configuration.user}; . 
.venv/bin/activate; {configure_worker_command(config=self.configuration)}"
+        )
+        assert cmd_result.exit_code == 0, f"Failed to configure Worker agent: {cmd_result}"
+        LOG.info("Successfully configured Worker agent")
+
+        LOG.info(f"Sending SSM command to start Worker agent on instance {self.instance_id}")
+        cmd_result = self.send_command(
+            " && ".join(
+                [
+                    f"nohup runuser --login {self.configuration.user} -c 'AWS_DEFAULT_REGION={self.configuration.region} deadline-worker-agent --allow-instance-profile >/dev/null 2>&1 &'",
+                    # Verify Worker is still running
+                    "echo Waiting 5s for agent to get started",
+                    "sleep 5",
+                    "echo 'Running pgrep to see if deadline-worker-agent is running'",
+                    f"pgrep --count --full -u {self.configuration.user} deadline-worker-agent",
+                ]
+            ),
+        )
+        assert cmd_result.exit_code == 0, f"Failed to start Worker agent: {cmd_result}"
+        LOG.info("Successfully started Worker agent")
+
+    @property
+    def worker_id(self) -> str:
+        cmd_result = self.send_command(f"cat {WORKER_JSON_PATH} | jq -r '.worker_id'")
+        assert cmd_result.exit_code == 0, f"Failed to get Worker ID: {cmd_result}"
+
+        worker_id = cmd_result.stdout.rstrip("\n\r")
+        assert re.match(
+            r"^worker-[0-9a-f]{32}$", worker_id
+        ), f"Got invalid Worker ID from command stdout: {cmd_result}"
+        return worker_id
+
+    @property
+    def ami_id(self) -> str:
+        if not hasattr(self, "_ami_id"):
+            # Grab the latest AL2023 AMI
+            # https://aws.amazon.com/blogs/compute/query-for-the-latest-amazon-linux-ami-ids-using-aws-systems-manager-parameter-store/
+            ssm_param_name = (
+                f"/aws/service/ami-amazon-linux-latest/{EC2InstanceWorker.AL2023_AMI_NAME}"
+            )
+            response = call_api(
+                description=f"Getting latest AL2023 AMI ID from SSM parameter {ssm_param_name}",
+                fn=lambda: self.ssm_client.get_parameters(Names=[ssm_param_name]),
+            )
+
+            parameters = response.get("Parameters", [])
+            assert (
+                len(parameters) == 1
+            ), f"Received incorrect number of SSM parameters. 
Expected 1, got response: {response}" + self._ami_id = parameters[0]["Value"] + LOG.info(f"Using latest AL2023 AMI {self._ami_id}") + + return self._ami_id + + +@dataclass +class DockerContainerWorker(DeadlineWorker): + configuration: DeadlineWorkerConfiguration + + _container_id: Optional[str] = field(init=False, default=None) + + def __post_init__(self) -> None: + # Do not install Worker agent service since it's recommended to avoid systemd usage on Docker containers + self.configuration = replace(self.configuration, no_install_service=True) + + def start(self) -> None: + self._tmpdir = pathlib.Path(tempfile.mkdtemp()) + + # Environment variables for "run_container.sh" + run_container_env = { + **os.environ, + "AGENT_USER": self.configuration.user, + "SHARED_GROUP": self.configuration.group, + "JOB_USER": "jobuser", + "CONFIGURE_WORKER_AGENT_CMD": configure_worker_command( + config=self.configuration, + ), + } + + LOG.info(f"Staging Docker build context directory {str(self._tmpdir)}") + shutil.copytree(DOCKER_CONTEXT_DIR, str(self._tmpdir), dirs_exist_ok=True) + + if self.configuration.file_mappings: + # Stage a special dir with files to copy over to a temp folder in the Docker container + # The container is responsible for copying files from that temp folder into the final destinations + file_mappings_dir = self._tmpdir / "file_mappings" + os.makedirs(str(file_mappings_dir)) + + # Mapping of files in temp Docker container folder to their final destination + docker_file_mappings: dict[str, str] = {} + for src, dst in self.configuration.file_mappings: + src_file_name = os.path.basename(src) + + # The Dockerfile copies the file_mappings dir in the build context to "/file_mappings" in the container + # Build up an array of mappings from "/file_mappings" to their final destination + src_docker_path = posixpath.join("/file_mappings", src_file_name) + assert src_docker_path not in docker_file_mappings, ( + "Duplicate paths generated for file mappings. All source files must have unique " + + f"filenames. 
Mapping: {self.configuration.file_mappings}"
+                )
+                docker_file_mappings[src_docker_path] = dst
+
+                # Copy the file over to the stage directory
+                shutil.copyfile(src, str(file_mappings_dir / src_file_name))
+
+            run_container_env["FILE_MAPPINGS"] = json.dumps(docker_file_mappings)
+
+        # Build and start the container
+        LOG.info("Starting Docker container")
+        try:
+            proc = subprocess.Popen(
+                args="./run_container.sh",
+                cwd=str(self._tmpdir),
+                env=run_container_env,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                encoding="utf-8",
+            )
+
+            # Live logging of Docker build
+            assert proc.stdout
+            with proc.stdout:
+                for line in iter(proc.stdout.readline, ""):
+                    LOG.info(line.rstrip("\r\n"))
+        except Exception as e:  # pragma: no cover
+            LOG.exception(f"Failed to start Worker agent Docker container: {e}")
+            _handle_subprocess_error(e)
+            raise
+        else:
+            exit_code = proc.wait(timeout=60)
+            assert exit_code == 0, f"Process failed with exit code {exit_code}"
+
+            # Grab the container ID from --cidfile
+            try:
+                self._container_id = subprocess.check_output(
+                    args=["cat", ".container_id"],
+                    cwd=str(self._tmpdir),
+                    text=True,
+                    encoding="utf-8",
+                    timeout=1,
+                ).rstrip("\r\n")
+            except Exception as e:  # pragma: no cover
+                LOG.exception(f"Failed to get Docker container ID: {e}")
+                _handle_subprocess_error(e)
+                raise
+            else:
+                LOG.info(f"Started Docker container {self._container_id}")
+
+    def stop(self) -> None:
+        assert (
+            self._container_id
+        ), "Cannot stop Docker container: Container ID is not set. Has the Docker container been started yet?"
+
+        LOG.info(f"Terminating Worker agent process in Docker container {self._container_id}")
+        try:
+            self.send_command(f"pkill --signal term -f {self.configuration.user}")
+        except Exception as e:  # pragma: no cover
+            LOG.exception(f"Failed to terminate Worker agent process: {e}")
+            raise
+        else:
+            LOG.info("Worker agent process terminated")
+
+        LOG.info(f"Stopping Docker container {self._container_id}")
+        try:
+            subprocess.check_output(
+                args=["docker", "container", "stop", self._container_id],
+                cwd=str(self._tmpdir),
+                text=True,
+                encoding="utf-8",
+                timeout=30,
+            )
+        except Exception as e:  # pragma: no cover
+            LOG.exception(f"Failed to stop Docker container {self._container_id}: {e}")
+            _handle_subprocess_error(e)
+            raise
+        else:
+            LOG.info(f"Stopped Docker container {self._container_id}")
+            self._container_id = None
+
+    def send_command(self, command: str, *, quiet: bool = False) -> CommandResult:
+        assert (
+            self._container_id
+        ), "Container ID not set. Has the Docker container been started yet?"
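+        # Each command runs through "docker exec" in a fresh bash shell with
+        # strict mode (-euo pipefail), mirroring the SSM-based send_command on
+        # EC2InstanceWorker; see the subprocess.run invocation below.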
+
+        if not quiet:  # pragma: no cover
+            LOG.info(f"Sending command '{command}' to Docker container {self._container_id}")
+        try:
+            result = subprocess.run(
+                args=[
+                    "docker",
+                    "exec",
+                    self._container_id,
+                    "/bin/bash",
+                    "-euo",
+                    "pipefail",
+                    "-c",
+                    command,
+                ],
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                encoding="utf-8",
+            )
+        except Exception as e:
+            if not quiet:  # pragma: no cover
+                LOG.exception(f"Failed to run command: {e}")
+                _handle_subprocess_error(e)
+            raise
+        else:
+            return CommandResult(
+                exit_code=result.returncode,
+                stdout=result.stdout,
+                stderr=result.stderr,
+            )
+
+    @property
+    def worker_id(self) -> str:
+        cmd_result: Optional[CommandResult] = None
+
+        def got_worker_id() -> bool:
+            nonlocal cmd_result
+            try:
+                cmd_result = self.send_command(
+                    f"cat {WORKER_JSON_PATH} | jq -r '.worker_id'",
+                    quiet=True,
+                )
+            except subprocess.CalledProcessError as e:
+                LOG.warning(f"Worker ID retrieval failed: {e}")
+                return False
+            else:
+                return cmd_result.exit_code == 0
+
+        wait_for(
+            description=f"retrieval of worker ID from {WORKER_JSON_PATH}",
+            predicate=got_worker_id,
+            interval_s=10,
+            max_retries=6,
+        )
+
+        assert isinstance(cmd_result, CommandResult)
+        cmd_result = cast(CommandResult, cmd_result)
+        assert cmd_result.exit_code == 0, f"Failed to get Worker ID: {cmd_result}"
+
+        worker_id = cmd_result.stdout.rstrip("\r\n")
+        assert re.match(
+            r"^worker-[0-9a-f]{32}$", worker_id
+        ), f"Got invalid Worker ID from command stdout: {cmd_result}"
+
+        return worker_id
+
+    @property
+    def container_id(self) -> str | None:
+        return self._container_id
+
+
+def _handle_subprocess_error(e: Any) -> None:  # pragma: no cover
+    if hasattr(e, "stdout"):
+        LOG.error(f"Command stdout: {e.stdout}")
+    if hasattr(e, "stderr"):
+        LOG.error(f"Command stderr: {e.stderr}")
diff --git a/src/deadline_test_scaffolding/deadline_manager.py b/src/deadline_test_scaffolding/deadline_manager.py
deleted file mode 100644
index d46be13..0000000
--- a/src/deadline_test_scaffolding/deadline_manager.py
+++ /dev/null
@@ -1,571 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-from __future__ import annotations - -import os -import posixpath -import sys -import tempfile -import uuid -from time import sleep -from typing import Any, Dict, Optional, List - -import boto3 -from botocore.client import BaseClient -from botocore.exceptions import ClientError -from botocore.loaders import Loader -from botocore.model import ServiceModel, OperationModel - -from .constants import ( - JOB_ATTACHMENTS_ROOT_PREFIX, - DEFAULT_CMF_CONFIG, -) - - -class DeadlineManager: - """This class is responsible for setting up and tearing down the required components - for the tests to be run.""" - - deadline_service_model_bucket: Optional[str] = None - deadline_endpoint: Optional[str] = None - - kms_client: BaseClient - kms_key_metadata: Optional[Dict[str, Any]] - - deadline_client: DeadlineClient - farm_id: Optional[str] - queue_id: Optional[str] - fleet_id: Optional[str] - job_attachment_bucket: Optional[str] - additional_queues: list[dict[str, Any]] - deadline_model_dir: Optional[tempfile.TemporaryDirectory] = None - - MOCKED_SERVICE_VERSION = "2020-08-21" - - def __init__(self, should_add_deadline_models: bool = False) -> None: - """ - Initializing the Deadline Manager - """ - self.deadline_service_model_bucket = os.getenv("DEADLINE_SERVICE_MODEL_BUCKET") - self.deadline_endpoint = os.getenv("DEADLINE_ENDPOINT") - - # Installing the deadline service models. - if should_add_deadline_models: - self.get_deadline_models() - - self.deadline_client = self._get_deadline_client(self.deadline_endpoint) - - # Create the KMS client - self.kms_client = boto3.client("kms") - - self.farm_id: Optional[str] = None - self.queue_id: Optional[str] = None - self.fleet_id: Optional[str] = None - self.additional_queues: list[dict[str, Any]] = [] - self.kms_key_metadata: Optional[dict[str, Any]] = None - - def get_deadline_models(self): - """ - This function will download and install the models for deadline so we can use the deadline - client. - """ - if self.deadline_service_model_bucket is None: - raise ValueError( - "Environment variable DEADLINE_SERVICE_MODEL_BUCKET is not set. " - "Unable to get deadline service model." - ) - - # Create the S3 client - s3_client: BaseClient = boto3.client("s3") - - # Create a temp directory to store the model file - self.deadline_model_dir = tempfile.TemporaryDirectory() - service_model_dir = posixpath.join( - self.deadline_model_dir.name, "deadline", self.MOCKED_SERVICE_VERSION - ) - os.makedirs(service_model_dir) - - # Downloading the deadline models. 
- s3_client.download_file( - self.deadline_service_model_bucket, - "service-2.json", - posixpath.join(service_model_dir, "service-2.json"), - ) - os.environ["AWS_DATA_PATH"] = self.deadline_model_dir.name - - def create_scaffolding( - self, - worker_role_arn: str, - job_attachments_bucket: str, - farm_name: str = uuid.uuid4().hex, - queue_name: str = uuid.uuid4().hex, - fleet_name: str = uuid.uuid4().hex, - ) -> None: - self.create_kms_key() - self.create_farm(farm_name) - self.create_queue(queue_name) - self.add_job_attachments_bucket(job_attachments_bucket) - self.create_fleet(fleet_name, worker_role_arn) - self.queue_fleet_association() - - def create_kms_key(self) -> None: - try: - response: Dict[str, Any] = self.kms_client.create_key( - Description="The KMS used for testing created by the " - "DeadlineClientSoftwareTestScaffolding.", - Tags=[{"TagKey": "Name", "TagValue": "DeadlineClientSoftwareTestScaffolding"}], - ) - except ClientError as e: - print("Failed to create CMK.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - self.kms_key_metadata = response["KeyMetadata"] - - # We should always get a metadata when successful, this is for mypy. - if self.kms_key_metadata: # pragma: no cover - print(f"Created CMK with id = {self.kms_key_metadata['KeyId']}") - self.kms_client.enable_key(KeyId=self.kms_key_metadata["KeyId"]) - print(f"Enabled CMK with id = {self.kms_key_metadata['KeyId']}") - - def delete_kms_key(self) -> None: - if ( - not hasattr(self, "kms_key_metadata") - or self.kms_key_metadata is None - or "KeyId" not in self.kms_key_metadata - ): - raise Exception("ERROR: Attempting to delete a KMS key when None was created!") - - try: - # KMS keys by default are deleted in 30 days (this is their pending window). - # 7 days is the fastest we can clean them up. 
- pending_window = 7 - self.kms_client.schedule_key_deletion( - KeyId=self.kms_key_metadata["KeyId"], PendingWindowInDays=pending_window - ) - except ClientError as e: - print( - "Failed to schedule the deletion of CMK with id = " - f"{self.kms_key_metadata['KeyId']}", - file=sys.stderr, - ) - print(f"The following error was raised: {e}", file=sys.stderr) - raise - else: - print(f"Scheduled deletion of CMK with id = {self.kms_key_metadata['KeyId']}") - self.kms_key_metadata = None - - def create_farm(self, farm_name: str) -> None: - if ( - not hasattr(self, "kms_key_metadata") - or self.kms_key_metadata is None - or "Arn" not in self.kms_key_metadata - ): - raise Exception("ERROR: Attempting to create a farm without having creating a CMK.") - - try: - response = self.deadline_client.create_farm( - displayName=farm_name, kmsKeyArn=self.kms_key_metadata["Arn"] - ) - except ClientError as e: - print("Failed to create a farm.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - self.farm_id = response["farmId"] - print(f"Successfully create farm with id = {self.farm_id}") - - def delete_farm(self) -> None: - if not hasattr(self, "farm_id") or not self.farm_id: - raise Exception("ERROR: Attempting to delete a farm without having created one.") - - try: - self.deadline_client.delete_farm(farmId=self.farm_id) - except ClientError as e: - print(f"Failed to delete farm with id = {self.farm_id}.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - print(f"Successfully deleted farm with id = {self.farm_id}") - self.farm_id = None - - # TODO: Add support for queue users with jobsRunAs - def create_queue(self, queue_name: str) -> None: - if not hasattr(self, "farm_id") or self.farm_id is None: - raise Exception( - "ERROR: Attempting to create a queue without having had created a farm!" - ) - - try: - response = self.deadline_client.create_queue( - displayName=queue_name, - farmId=self.farm_id, - ) - except ClientError as e: - print(f"Failed to create queue with displayName = {queue_name}.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - self.queue_id = response["queueId"] - print(f"Successfully created queue with id = {self.queue_id}") - - def add_job_attachments_bucket(self, job_attachments_bucket: str): - """Add a job attachments bucket to the queue""" - self.deadline_client.update_queue( - queueId=self.queue_id, - farmId=self.farm_id, - jobAttachmentSettings={ - "s3BucketName": job_attachments_bucket, - "rootPrefix": JOB_ATTACHMENTS_ROOT_PREFIX, - }, - ) - - def create_additional_queue(self, **kwargs) -> Dict[str, Any]: - """Create and add another queue to the deadline manager""" - input = {"farmId": self.farm_id} - input.update(kwargs) - response = self.deadline_client.create_queue(**input) - response = self.deadline_client.get_queue( - farmId=input["farmId"], queueId=response["queueId"] - ) - self.additional_queues.append(response) - return response - - def delete_queue(self) -> None: - if not hasattr(self, "farm_id") or not self.farm_id: - raise Exception( - "ERROR: Attempting to delete a queue without having had created a farm!" 
- ) - - if not hasattr(self, "queue_id") or not self.queue_id: - raise Exception("ERROR: Attempting to delete a queue without having had created one!") - - try: - self.deadline_client.delete_queue(queueId=self.queue_id, farmId=self.farm_id) - except ClientError as e: - print(f"Failed to delete queue with id = {self.queue_id}.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - print(f"Successfully deleted queue with id = {self.queue_id}") - self.queue_id = None - - def delete_additional_queues(self) -> None: - """Delete all additional queues that have been added.""" - for queue in self.additional_queues: - try: - self.deadline_client.delete_queue(farmId=queue["farmId"], queueId=queue["queueId"]) - except Exception as e: - print(f"delete queue exception {str(e)}") - continue - - def create_fleet(self, fleet_name: str, worker_role_arn: str) -> None: - if not hasattr(self, "farm_id") or not self.farm_id: - raise Exception( - "ERROR: Attempting to create a fleet without having had created a farm!" - ) - try: - response = self.deadline_client.create_fleet( - farmId=self.farm_id, - displayName=fleet_name, - roleArn=worker_role_arn, - configuration=DEFAULT_CMF_CONFIG, - ) - except ClientError as e: - print(f"Failed to create fleet with displayName = {fleet_name}.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - self.fleet_id = response["fleetId"] - self.wait_for_desired_fleet_status( - desired_status="ACTIVE", allowed_status=["ACTIVE", "CREATE_IN_PROGRESS"] - ) - print(f"Successfully created a fleet with id = {self.fleet_id}") - - # Temporary until we have waiters - def wait_for_desired_fleet_status(self, desired_status: str, allowed_status: List[str]) -> None: - max_retries = 10 - fleet_status = None - retry_count = 0 - while fleet_status != desired_status and retry_count < max_retries: - response = self.deadline_client.get_fleet(fleetId=self.fleet_id, farmId=self.farm_id) - - fleet_status = response["status"] - - if fleet_status not in allowed_status: - raise ValueError( - f"fleet entered a nonvalid status ({fleet_status}) while " - f"waiting for the desired status: {desired_status}." - ) - - if fleet_status == desired_status: - return response - - print(f"Fleet status: {fleet_status}\nChecking again...") - retry_count += 1 - sleep(10) - - raise ValueError( - f"Timed out waiting for fleet status to reach the desired status {desired_status}." 
- ) - - def queue_fleet_association(self) -> None: - if not hasattr(self, "farm_id") or not self.farm_id: - raise Exception("ERROR: Attempting to queue a fleet without having had created a farm!") - - if not hasattr(self, "queue_id") or not self.queue_id: - raise Exception("ERROR: Attempting to queue a fleet without creating a queue") - - if not hasattr(self, "fleet_id") or not self.fleet_id: - raise Exception("ERROR: Attempting to queue a fleet without having had created one!") - - try: - self.deadline_client.create_queue_fleet_association( - farmId=self.farm_id, queueId=self.queue_id, fleetId=self.fleet_id - ) - except ClientError as e: - print(f"Failed to associate fleet with id = {self.fleet_id}.", file=sys.stderr) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - print(f"Successfully queued fleet with id = {self.fleet_id}") - - # Temporary until we have waiters - def stop_queue_fleet_associations_and_wait(self) -> None: - self.deadline_client.update_queue_fleet_association( - farmId=self.farm_id, - queueId=self.queue_id, - fleetId=self.fleet_id, - status="CANCEL_WORK", - ) - max_retries = 10 - retry_count = 0 - qfa_status = None - allowed_status = ["STOPPED", "CANCEL_WORK"] - while qfa_status != "STOPPED" and retry_count < max_retries: - response = self.deadline_client.get_queue_fleet_association( - farmId=self.farm_id, queueId=self.queue_id, fleetId=self.fleet_id - ) - - qfa_status = response["status"] - - if qfa_status not in allowed_status: - raise ValueError( - f"Association entered a nonvalid status ({qfa_status}) while " - f"waiting for the desired status: STOPPED" - ) - - if qfa_status == "STOPPED": - return response - - print(f"Queue Fleet Association: {qfa_status}\nChecking again...") - retry_count += 1 - sleep(10) - raise ValueError("Timed out waiting for association to reach a STOPPED status.") - - def delete_fleet(self) -> None: - if not hasattr(self, "farm_id") or not self.farm_id: - raise Exception( - "ERROR: Attempting to delete a fleet without having had created a farm!" - ) - - if not hasattr(self, "fleet_id") or not self.fleet_id: - raise Exception("ERROR: Attempting to delete a fleet when none was created!") - - try: - # Delete queue fleet association. - self.stop_queue_fleet_associations_and_wait() - self.deadline_client.delete_queue_fleet_association( - farmId=self.farm_id, queueId=self.queue_id, fleetId=self.fleet_id - ) - # Deleting the fleet. - self.deadline_client.delete_fleet(farmId=self.farm_id, fleetId=self.fleet_id) - except ClientError as e: - print( - f"ERROR: Failed to delete delete fleet with id = {self.fleet_id}", file=sys.stderr - ) - print(f"The following exception was raised: {e}", file=sys.stderr) - raise - else: - print(f"Successfully deleted fleet with id = {self.fleet_id}") - self.fleet_id = None - - def cleanup_scaffolding(self) -> None: - # Only deleting the fleet if we have a fleet. - if hasattr(self, "fleet_id") and self.fleet_id: - self.delete_fleet() - - if hasattr(self, "farm_id") and self.farm_id: - # Only deleting the queue if we have a queue. - if hasattr(self, "queue_id") and self.queue_id: - self.delete_queue() - - self.delete_farm() - - # Only deleting the kms key if we have a kms key. 
- if hasattr(self, "kms_key_metadata") and self.kms_key_metadata: - self.delete_kms_key() - - def _get_deadline_client(self, deadline_endpoint: Optional[str]) -> DeadlineClient: - """Create a DeadlineClient shim layer over an actual boto client""" - self.session = boto3.Session() - real_deadline_client = self.session.client( - "deadline", - endpoint_url=deadline_endpoint, - ) - - return DeadlineClient(real_deadline_client) - - -class DeadlineClient: - """ - A shim layer for boto Deadline client. This class will check if a method exists on the real - boto3 Deadline client and call it if it exists. If it doesn't exist, an AttributeError will be raised. - """ - - _real_client: Any - - def __init__(self, real_client: Any) -> None: - self._real_client = real_client - - def create_farm(self, *args, **kwargs) -> Any: - create_farm_input_members = self._get_deadline_api_input_shape("CreateFarm") - if "displayName" not in create_farm_input_members and "name" in create_farm_input_members: - kwargs["name"] = kwargs.pop("displayName") - return self._real_client.create_farm(*args, **kwargs) - - def create_fleet(self, *args, **kwargs) -> Any: - create_fleet_input_members = self._get_deadline_api_input_shape("CreateFleet") - if "displayName" not in create_fleet_input_members and "name" in create_fleet_input_members: - kwargs["name"] = kwargs.pop("displayName") - if ( - "roleArn" not in create_fleet_input_members - and "workeRoleArn" in create_fleet_input_members - ): - kwargs["workerRoleArn"] = kwargs.pop("roleArn") - return self._real_client.create_fleet(*args, **kwargs) - - def get_fleet(self, *args, **kwargs) -> Any: - response = self._real_client.get_fleet(*args, **kwargs) - if "name" in response and "displayName" not in response: - response["displayName"] = response["name"] - del response["name"] - if "state" in response and "status" not in response: - response["status"] = response["state"] - del response["state"] - if "type" in response: - del response["type"] - return response - - def get_queue_fleet_association(self, *args, **kwargs) -> Any: - response = self._real_client.get_queue_fleet_association(*args, **kwargs) - if "state" in response and "status" not in response: - response["status"] = response["state"] - del response["state"] - return response - - def create_queue(self, *args, **kwargs) -> Any: - create_queue_input_members = self._get_deadline_api_input_shape("CreateQueue") - if "displayName" not in create_queue_input_members and "name" in create_queue_input_members: - kwargs["name"] = kwargs.pop("displayName") - return self._real_client.create_queue(*args, **kwargs) - - def create_queue_fleet_association(self, *args, **kwargs) -> Any: - create_queue_fleet_association_method_name: Optional[str] - create_queue_fleet_association_method: Optional[str] - - for create_queue_fleet_association_method_name in ( - "put_queue_fleet_association", - "create_queue_fleet_association", - ): - create_queue_fleet_association_method = getattr( - self._real_client, create_queue_fleet_association_method_name, None - ) - if create_queue_fleet_association_method: - break - else: - create_queue_fleet_association_method = None - - # mypy complains about they kwargs type - return create_queue_fleet_association_method(*args, **kwargs) # type: ignore - - def create_job(self, *args, **kwargs) -> Any: - create_job_input_members = self._get_deadline_api_input_shape("CreateJob") - # revert to old parameter names if old service model is used - if "maxRetriesPerTask" in kwargs: - if "maxErrorsPerTask" in 
create_job_input_members: - kwargs["maxErrorsPerTask"] = kwargs.pop("maxRetriesPerTask") - if "template" in kwargs: - if "jobTemplate" in create_job_input_members: - kwargs["jobTemplate"] = kwargs.pop("template") - kwargs["jobTemplateType"] = kwargs.pop("templateType") - if "parameters" in kwargs: - kwargs["jobParameters"] = kwargs.pop("parameters") - if "targetTaskRunStatus" in kwargs: - if "initialState" in create_job_input_members: - kwargs["initialState"] = kwargs.pop("targetTaskRunStatus") - if "priority" not in kwargs: - kwargs["priority"] = 50 - return self._real_client.create_job(*args, **kwargs) - - def update_queue_fleet_association(self, *args, **kwargs) -> Any: - update_queue_fleet_association_method_name: Optional[str] - update_queue_fleet_association_method: Optional[str] - - for update_queue_fleet_association_method_name in ( - "update_queue_fleet_association", - "update_queue_fleet_association_state", - ): - update_queue_fleet_association_method = getattr( - self._real_client, update_queue_fleet_association_method_name, None - ) - if update_queue_fleet_association_method: - break - else: - update_queue_fleet_association_method = None - - if update_queue_fleet_association_method_name == "update_queue_fleet_association": - # mypy complains about they kwargs type - return update_queue_fleet_association_method(*args, **kwargs) # type: ignore - - if update_queue_fleet_association_method_name == "update_queue_fleet_association_state": - kwargs["state"] = kwargs.pop("status") - # mypy complains about they kwargs type - return update_queue_fleet_association_method(*args, **kwargs) # type: ignore - - def _get_deadline_api_input_shape(self, api_name: str) -> dict[str, Any]: - """ - Given a string name of an API e.g. CreateJob, returns the shape of the - inputs to that API. - """ - api_model = self._get_deadline_api_model(api_name) - if api_model: - return api_model.input_shape.members - return {} - - def _get_deadline_api_model(self, api_name: str) -> Optional[OperationModel]: - """ - Given a string name of an API e.g. CreateJob, returns the OperationModel - for that API from the service model. - """ - data_model_path = os.getenv("AWS_DATA_PATH") - loader = Loader(extra_search_paths=[data_model_path] if data_model_path is not None else []) - deadline_service_description = loader.load_service_model("deadline", "service-2") - deadline_service_model = ServiceModel(deadline_service_description, service_name="deadline") - return OperationModel( - deadline_service_description["operations"][api_name], deadline_service_model - ) - - def __getattr__(self, __name: str) -> Any: - """ - Respond to unknown method calls by calling the underlying _real_client - If the underlying _real_client does not have a given method, an AttributeError - will be raised. - Note that __getattr__ is only called if the attribute cannot otherwise be found, - so if this class alread has the called method defined, __getattr__ will not be called. - This is in opposition to __getattribute__ which is called by default. - """ - - def method(*args, **kwargs): - return getattr(self._real_client, __name)(*args, **kwargs) - - return method diff --git a/src/deadline_test_scaffolding/example_config.sh b/src/deadline_test_scaffolding/example_config.sh new file mode 100644 index 0000000..7a4c488 --- /dev/null +++ b/src/deadline_test_scaffolding/example_config.sh @@ -0,0 +1,147 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# This is an example configuration for the deadline-cloud-test-fixtures package + + +# If "true", the Docker worker is used instead of the EC2 worker. Default is to use the EC2 worker. +# For EC2 Worker configuration, see the "EC2 WORKER OPTIONS" section below +# For Docker Worker configuration, see the "DOCKER WORKER OPTIONS" section below +export USE_DOCKER_WORKER + + +# ====================== # +# === COMMON OPTIONS === # +# ====================== # + +# --- REQUIRED --- # + +# The AWS account ID to deploy infrastructure into +export SERVICE_ACCOUNT_ID + +# CodeArtifact repository information to configure pip to pull Python dependencies from +# +# The domain owner AWS account ID +export CODEARTIFACT_ACCOUNT_ID +# The domain the repository is in +export CODEARTIFACT_DOMAIN +# The name of the repository +export CODEARTIFACT_REPOSITORY +# The region the repository is in +export CODEARTIFACT_REGION + +# --- OPTIONAL --- # + +# Extra local path for boto to look for AWS models in +# Does not apply to the worker +export AWS_DATA_PATH + +# Local path to the Worker agent .whl file to use for the tests +# Default is to pip install the latest "deadline-cloud-worker-agent" package +export WORKER_AGENT_WHL_PATH + +# The AWS region to configure the worker for +# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2 +export WORKER_REGION + +# The POSIX user to configure the worker for +# Defaults to "deadline-worker" +export WORKER_POSIX_USER + +# The shared POSIX group to configure the worker user and job user with +# Defaults to "shared-group" +export WORKER_POSIX_SHARED_GROUP + +# PEP 508 requirement specifier for the Worker agent package +# If WORKER_AGENT_WHL_PATH is provided, this option is ignored +export WORKER_AGENT_REQUIREMENT_SPECIFIER + +# The S3 URI for the "deadline" service model to use for the tests +# Falls back to LOCAL_MODEL_PATH, then defaults to your locally installed service model +export DEADLINE_SERVICE_MODEL_S3_URI + +# Path to a local Deadline model file to use for API calls +# If DEADLINE_SERVICE_MODEL_S3_URI is provided, this option is ignored +# Default is to use the locally installed service model on your machine +export LOCAL_MODEL_PATH + +# The endpoint to use for requests to the Amazon Deadline Cloud service +# Default is the endpoint specified in your AWS model file for "deadline" +export DEADLINE_ENDPOINT + +# The CredentialVending service principal to configure the Worker IAM roles with +# If you don't know what this is, then you probably don't need to provide this +export CREDENTIAL_VENDING_PRINCIPAL + +# Used as an infix for the S3 bucket deployed by the "deploy_job_attachment_resources" fixture +# Defaults to "dev" +export STAGE + +# If set to "true", does not stop the worker after test failure. Useful for debugging. 
+export KEEP_WORKER_AFTER_FAILURE
+
+
+
+# If BYO_DEADLINE is "true", uses existing Deadline resource IDs as specified below
+# By default, new resources are deployed for you that get deleted after test runs
+export BYO_DEADLINE
+# Required - The ID of the farm to use
+export FARM_ID
+# Required - The ID of the queue to use
+export QUEUE_ID
+# Required - The ID of the fleet to use
+export FLEET_ID
+# Optional - The ID of the KMS key associated with your farm
+# If you use this option, then you must BYO_BOOTSTRAP because the default IAM role created for
+# the Worker will not have sufficient permissions to access this key
+export FARM_KMS_KEY_ID
+# Optional - The name of the S3 bucket to use for Job Attachments
+export JOB_ATTACHMENTS_BUCKET
+
+
+
+# If BYO_BOOTSTRAP is "true", uses existing bootstrap resources as specified below
+# By default, new resources are deployed for you in a CloudFormation stack.
+# This stack is not destroyed automatically after test runs.
+export BYO_BOOTSTRAP
+# Required - The name of the S3 bucket to use for bootstrapping files
+export BOOTSTRAP_BUCKET_NAME
+# Required - ARN of the IAM role to use for the Worker
+export WORKER_ROLE_ARN
+# Optional - ARN of the IAM role to use for sessions running on the Worker
+export SESSION_ROLE_ARN
+# Optional - Name of the IAM instance profile to bootstrap the Worker instance with
+# This option does not apply if you USE_DOCKER_WORKER
+export WORKER_INSTANCE_PROFILE_NAME
+
+
+
+# ========================== #
+# === EC2 WORKER OPTIONS === #
+# ========================== #
+
+# --- REQUIRED --- #
+
+# Subnet to deploy the EC2 instance into
+export SUBNET_ID
+
+# Security group to deploy the EC2 instance into
+export SECURITY_GROUP_ID
+
+# --- OPTIONAL --- #
+
+# AMI ID to use for the EC2 instance
+# Defaults to latest AL2023 AMI
+export AMI_ID
+
+
+
+# ============================= #
+# === DOCKER WORKER OPTIONS === #
+# ============================= #
+
+# --- REQUIRED --- #
+
+# None yet
+
+# --- OPTIONAL --- #
+
+# None yet
diff --git a/src/deadline_test_scaffolding/fixtures.py b/src/deadline_test_scaffolding/fixtures.py
index ccf7378..5133576 100644
--- a/src/deadline_test_scaffolding/fixtures.py
+++ b/src/deadline_test_scaffolding/fixtures.py
@@ -1,335 +1,476 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+from __future__ import annotations import botocore +import botocore.client +import botocore.loaders import boto3 +import glob +import logging import os -import time +import posixpath import pytest -import json -from typing import Any, Callable, Generator, Dict, Optional, Type -from types import TracebackType - -from .deadline_manager import DeadlineManager -from .job_attachment_manager import JobAttachmentManager -from .utils import ( - generate_worker_role_cfn_template, - generate_boostrap_worker_role_cfn_template, - generate_boostrap_instance_profile_cfn_template, - generate_queue_session_role, - generate_job_attachments_bucket, - generate_job_attachments_bucket_policy, +from dataclasses import InitVar, dataclass, field, fields, MISSING +from typing import Any, Generator + +from .deadline.client import DeadlineClient +from .deadline.resources import ( + Farm, + Fleet, + Queue, + QueueFleetAssociation, ) - -from .constants import ( - DEADLINE_WORKER_ROLE, - DEADLINE_WORKER_BOOTSTRAP_ROLE, - DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME, - DEADLINE_QUEUE_SESSION_ROLE, - DEADLINE_SERVICE_MODEL_BUCKET, - CODEARTIFACT_DOMAIN, - CODEARTIFACT_ACCOUNT_ID, - CODEARTIFACT_REPOSITORY, - JOB_ATTACHMENTS_BUCKET_NAME, - JOB_ATTACHMENTS_BUCKET_RESOURCE, - JOB_ATTACHMENTS_BUCKET_POLICY_RESOURCE, - BOOTSTRAP_CLOUDFORMATION_STACK_NAME, - STAGE, +from .deadline.worker import ( + DeadlineWorker, + DeadlineWorkerConfiguration, + DockerContainerWorker, + EC2InstanceWorker, + PipInstall, ) +from .models import CodeArtifactRepositoryInfo, JobAttachmentSettings, ServiceModel, S3Object +from .cloudformation import WorkerBootstrapStack +from .job_attachment_manager import JobAttachmentManager -AMI_ID = os.environ.get("AMI_ID", "") -SUBNET_ID = os.environ.get("SUBNET_ID", "") -SECURITY_GROUP_ID = os.environ.get("SECURITY_GROUP_ID", "") - - -@pytest.fixture(scope="session") -def stage() -> str: - if os.getenv("LOCAL_DEVELOPMENT", "false").lower() == "true": - return "dev" - else: - return os.environ["STAGE"] - +LOG = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def account_id() -> str: - return os.environ["SERVICE_ACCOUNT_ID"] +@dataclass(frozen=True) +class BootstrapResources: + bootstrap_bucket_name: str + worker_role_arn: str + session_role_arn: str | None = None + worker_instance_profile_name: str | None = None -# Boto client fixtures -@pytest.fixture(scope="session") -def session() -> boto3.Session: - return boto3.Session() + job_attachments: JobAttachmentSettings | None = field(init=False, default=None) + job_attachments_bucket_name: InitVar[str | None] = None + job_attachments_root_prefix: InitVar[str | None] = None + def __post_init__( + self, + job_attachments_bucket_name: str | None, + job_attachments_root_prefix: str | None, + ) -> None: + if job_attachments_bucket_name or job_attachments_root_prefix: + assert ( + job_attachments_bucket_name and job_attachments_root_prefix + ), "Cannot provide partial Job Attachments settings, both bucket name and root prefix are required" + object.__setattr__( + self, + "job_attachments", + JobAttachmentSettings( + bucket_name=job_attachments_bucket_name, + root_prefix=job_attachments_root_prefix, + ), + ) -@pytest.fixture(scope="session") -def iam_client(session: boto3.Session) -> botocore.client.BaseClient: - return session.client("iam") +@dataclass(frozen=True) +class DeadlineResources: + farm: Farm = field(init=False) + queue: Queue = field(init=False) + fleet: Fleet = field(init=False) -@pytest.fixture(scope="session") -def ec2_client(session: 
boto3.Session) -> botocore.client.BaseClient: - return session.client("ec2") + farm_id: InitVar[str] + queue_id: InitVar[str] + fleet_id: InitVar[str] + farm_kms_key_id: str | None = None + job_attachments_bucket: str | None = None -@pytest.fixture(scope="session") -def ssm_client(session: boto3.Session) -> botocore.client.BaseClient: - return session.client("ssm") + def __post_init__( + self, + farm_id: str, + queue_id: str, + fleet_id: str, + ) -> None: + object.__setattr__(self, "farm", Farm(id=farm_id)) + object.__setattr__(self, "queue", Queue(id=queue_id, farm=self.farm)) + object.__setattr__(self, "fleet", Fleet(id=fleet_id, farm=self.farm)) @pytest.fixture(scope="session") -def cfn_client(session: boto3.Session) -> botocore.client.BaseClient: - return session.client("cloudformation") +def deadline_client() -> DeadlineClient: + endpoint_url = os.getenv("DEADLINE_ENDPOINT") + if endpoint_url: + LOG.info(f"Using Amazon Deadline Cloud endpoint: {endpoint_url}") - -# Bootstrap persistent resources -@pytest.fixture( - scope="session", autouse=os.environ.get("SKIP_BOOTSTRAP_TEST_RESOURCES", "False") != "True" -) -def bootstrap_test_resources(cfn_client: botocore.client.BaseClient) -> None: - # All required resources are created using CloudFormation stack - cfn_template: dict[str, Any] = { - "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Stack created by deadline-cloud-test-fixtures", - "Resources": { - # A role for use by the Worker Agent after being bootstrapped - DEADLINE_WORKER_ROLE: generate_worker_role_cfn_template(), - DEADLINE_WORKER_BOOTSTRAP_ROLE: generate_boostrap_worker_role_cfn_template(), - DEADLINE_QUEUE_SESSION_ROLE: generate_queue_session_role(), - DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME: generate_boostrap_instance_profile_cfn_template(), - JOB_ATTACHMENTS_BUCKET_RESOURCE: generate_job_attachments_bucket(), - JOB_ATTACHMENTS_BUCKET_POLICY_RESOURCE: generate_job_attachments_bucket_policy(), - }, - } - stack_name = BOOTSTRAP_CLOUDFORMATION_STACK_NAME - update_or_create_cfn_stack(cfn_client, stack_name, cfn_template) - - -# create or update bootstrap -def update_or_create_cfn_stack( - cfn_client: botocore.client.BaseClient, stack_name: str, cfn_template: Dict[str, Any] -) -> None: - try: - cfn_client.update_stack( - StackName=stack_name, - TemplateBody=json.dumps(cfn_template), - Capabilities=["CAPABILITY_NAMED_IAM"], - ) - waiter = cfn_client.get_waiter("stack_update_complete") - waiter.wait(StackName=stack_name) - except cfn_client.exceptions.ClientError as e: - if e.response["Error"]["Message"] != "No updates are to be performed.": - cfn_client.create_stack( - StackName=stack_name, - TemplateBody=json.dumps(cfn_template), - Capabilities=["CAPABILITY_NAMED_IAM"], - OnFailure="DELETE", - EnableTerminationProtection=False, - ) - waiter = cfn_client.get_waiter("stack_create_complete") - waiter.wait(StackName=stack_name) + return DeadlineClient(boto3.client("deadline", endpoint_url=endpoint_url)) @pytest.fixture(scope="session") -def deadline_manager_fixture(): - deadline_manager_fixture = DeadlineManager(should_add_deadline_models=True) - yield deadline_manager_fixture +def codeartifact() -> CodeArtifactRepositoryInfo: + """ + Gets the information for the CodeArtifact repository to use for Python dependencies. 
+
+    Environment Variables:
+        CODEARTIFACT_REGION: The region the CodeArtifact repository is in
+        CODEARTIFACT_DOMAIN: The domain of the CodeArtifact repository
+        CODEARTIFACT_ACCOUNT_ID: The AWS account ID which owns the domain
+        CODEARTIFACT_REPOSITORY: The name of the CodeArtifact repository
+
+    Returns:
+        CodeArtifactRepositoryInfo: Info about the CodeArtifact repository
+    """
+    return CodeArtifactRepositoryInfo(
+        region=os.environ["CODEARTIFACT_REGION"],
+        domain=os.environ["CODEARTIFACT_DOMAIN"],
+        domain_owner=os.environ["CODEARTIFACT_ACCOUNT_ID"],
+        repository=os.environ["CODEARTIFACT_REPOSITORY"],
+    )

-# get the worker role arn
 @pytest.fixture(scope="session")
-def worker_role_arn(iam_client: botocore.client.BaseClient) -> str:
-    response = iam_client.get_role(RoleName=DEADLINE_WORKER_ROLE)
-    return response["Role"]["Arn"]
+def service_model_s3_object() -> S3Object | None:
+    service_model_s3_uri = os.getenv("DEADLINE_SERVICE_MODEL_S3_URI")
+    return S3Object.from_uri(service_model_s3_uri) if service_model_s3_uri else None

 @pytest.fixture(scope="session")
-def deadline_scaffolding(
-    deadline_manager_fixture: DeadlineManager, worker_role_arn: str
-) -> Generator[Any, None, None]:
-    deadline_manager_fixture.create_scaffolding(worker_role_arn, JOB_ATTACHMENTS_BUCKET_NAME)
-
-    yield deadline_manager_fixture
-
-    deadline_manager_fixture.cleanup_scaffolding()
-
+def bootstrap_resources(request: pytest.FixtureRequest) -> BootstrapResources:
+    """
+    Gets Bootstrap resources required for running tests.
+
+    Environment Variables:
+        SERVICE_ACCOUNT_ID: ID of the AWS account to deploy the bootstrap stack into.
+            This option is ignored if BYO_BOOTSTRAP is set to "true"
+        CREDENTIAL_VENDING_PRINCIPAL: The credential vending service principal to use.
+            Defaults to credential-vending.deadline-closed-beta.amazonaws.com
+            This option is ignored if BYO_BOOTSTRAP is set to "true"
+        BYO_BOOTSTRAP: Whether the bootstrap stack deployment should be skipped.
+            If this is set to "true", environment values must be specified to fill the resources.
+            <FIELD_NAME>: Corresponds to a field in the BootstrapResources class with an uppercase name.
+                e.g. WORKER_ROLE_ARN -> BootstrapResources.worker_role_arn
+
+    Returns:
+        BootstrapResources: The bootstrap resources used for tests
+    """
+
+    if os.environ.get("BYO_BOOTSTRAP", "false").lower() == "true":
+        kwargs: dict[str, Any] = {}
+
+        all_fields = fields(BootstrapResources)
+        for f in all_fields:
+            env_var = f.name.upper()
+            if env_var in os.environ:
+                kwargs[f.name] = os.environ[env_var]
+
+        required_fields = [f for f in all_fields if (MISSING == f.default == f.default_factory)]
+        assert all([rf.name in kwargs for rf in required_fields]), (
+            "Not all bootstrap resources have been fulfilled via environment variables. Expected "
+            + f"values for {[f.name.upper() for f in required_fields]}, but got {kwargs}"
+        )
+        LOG.info(
+            f"All bootstrap resources have been fulfilled via environment variables. 
Using {kwargs}" + ) + return BootstrapResources(**kwargs) + else: + account = os.environ["SERVICE_ACCOUNT_ID"] + codeartifact: CodeArtifactRepositoryInfo = request.getfixturevalue("codeartifact") + service_model_s3_object: S3Object | None = request.getfixturevalue( + "service_model_s3_object" + ) + crednetial_vending_service_principal = os.getenv( + "CREDENTIAL_VENDING_PRINCIPAL", + "credential-vending.deadline-closed-beta.amazonaws.com", + ) -@pytest.fixture(scope="session") -def launch_instance(ec2_client: botocore.client.BaseClient) -> Generator[Any, None, None]: - with _InstanceLauncher( - ec2_client, - AMI_ID, - SUBNET_ID, - SECURITY_GROUP_ID, - DEADLINE_WORKER_BOOSTRAP_INSTANCE_PROFILE_NAME, - ) as instance_id: - yield instance_id + stack_name = "DeadlineScaffoldingWorkerBootstrapStack" + LOG.info(f"Deploying bootstrap stack {stack_name}") + stack = WorkerBootstrapStack( + name=stack_name, + codeartifact=codeartifact, + account=account, + credential_vending_service_principal=crednetial_vending_service_principal, + service_model_s3_object_arn=service_model_s3_object.arn + if service_model_s3_object + else None, + ) + stack.deploy(cfn_client=boto3.client("cloudformation")) + + return BootstrapResources( + bootstrap_bucket_name=stack.bootstrap_bucket.physical_name, + worker_role_arn=stack.worker_role.format_arn(account=account), + session_role_arn=stack.session_role.format_arn(account=account), + worker_instance_profile_name=stack.worker_instance_profile.physical_name, + job_attachments_bucket_name=stack.job_attachments_bucket.physical_name, + job_attachments_root_prefix="root", + ) @pytest.fixture(scope="session") -def create_worker_agent( - deadline_scaffolding, launch_instance: str, send_ssm_command: Callable -) -> Generator[Any, None, None]: - def configure_worker_agent_func() -> Dict: - """Creates a Deadline Farm, starts an instance and configures and starts a Worker Agent.""" - assert deadline_scaffolding - assert launch_instance - - configuration_command_response = send_ssm_command( - launch_instance, - ( - f"adduser -r -m agentuser && \n" - f"adduser -r -m jobuser && \n" - f"usermod -a -G jobuser agentuser && \n" - f"chmod 770 /home/jobuser && \n" - f"touch /etc/sudoers.d/deadline-worker-job-user && \n" - f'echo "agentuser ALL=(jobuser) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/deadline-worker-job-user && \n' - f"python3.9 -m venv /opt/deadline/worker && \n" - f"source /opt/deadline/worker/bin/activate && \n" - f"pip install --upgrade pip && \n" - f"touch /opt/deadline/worker/pip.conf && \n" - # TODO: Remove when pypi is available - f"aws codeartifact login --tool pip --domain {CODEARTIFACT_DOMAIN} --domain-owner {CODEARTIFACT_ACCOUNT_ID} --repository {CODEARTIFACT_REPOSITORY} && \n" - f"aws s3 cp s3://{DEADLINE_SERVICE_MODEL_BUCKET}/service-2.json /tmp/deadline-beta-2020-08-21.json && \n" - f"chmod +r /tmp/deadline-beta-2020-08-21.json && \n" - f"sudo -u agentuser aws configure add-model --service-model file:///tmp/deadline-beta-2020-08-21.json --service-name deadline && \n" - f"mkdir /var/lib/deadline /var/log/amazon/deadline/ && \n" - f"chown agentuser:agentuser /var/lib/deadline /var/log/amazon/deadline/ && \n" - f"pip install deadline-worker-agent && \n" - f"sudo -u agentuser /opt/deadline/worker/bin/deadline_worker_agent --help" - ), +def deadline_resources( + request: pytest.FixtureRequest, deadline_client: DeadlineClient +) -> Generator[DeadlineResources, None, None]: + """ + Gets Deadline resources required for running tests. 
+
+    Environment Variables:
+        BYO_DEADLINE: Whether the Deadline resource deployment should be skipped.
+            If this is set to "true", environment values must be specified to fill the resources.
+            <FIELD_NAME>: Corresponds to a field in the DeadlineResources class with an uppercase name.
+                e.g. FARM_ID -> DeadlineResources.farm_id
+
+    Returns:
+        DeadlineResources: The Deadline resources used for tests
+    """
+    if os.getenv("BYO_DEADLINE", "false").lower() == "true":
+        kwargs: dict[str, Any] = {}
+
+        all_fields = fields(DeadlineResources)
+        for f in all_fields:
+            env_var = f.name.upper()
+            if env_var in os.environ:
+                kwargs[f.name] = os.environ[env_var]
+
+        required_fields = [f for f in all_fields if (MISSING == f.default == f.default_factory)]
+        assert all([rf.name in kwargs for rf in required_fields]), (
+            "Not all Deadline resources have been fulfilled via environment variables. Expected "
+            + f"values for {[f.name.upper() for f in required_fields]}, but got {kwargs}"
+        )
-        )
-
-        return configuration_command_response
-
-    def start_worker_agent_func() -> Dict:
-        start_command_response = send_ssm_command(
-            launch_instance,
-            (
-                f"nohup sudo -E AWS_DEFAULT_REGION=us-west-2 -u agentuser /opt/deadline/worker/bin/deadline_worker_agent --farm-id {deadline_scaffolding.farm_id} --fleet-id {deadline_scaffolding.fleet_id} --allow-instance-profile >/dev/null 2>&1 &"
-            ),
+        LOG.info(
+            f"All Deadline resources have been fulfilled via environment variables. Using {kwargs}"
+        )
+        yield DeadlineResources(**kwargs)
+    else:
+        LOG.info("Deploying Deadline resources")
+        bootstrap_resources: BootstrapResources = request.getfixturevalue("bootstrap_resources")
+        farm = Farm.create(
+            client=deadline_client,
+            display_name="test-scaffolding-farm",
+        )
+        queue = Queue.create(
+            client=deadline_client,
+            display_name="test-scaffolding-queue",
+            farm=farm,
+            job_attachments=bootstrap_resources.job_attachments,
+            role_arn=bootstrap_resources.session_role_arn,
+        )
+        fleet = Fleet.create(
+            client=deadline_client,
+            display_name="test-scaffolding-fleet",
+            farm=farm,
+            configuration={
+                "customerManaged": {
+                    "autoScalingConfiguration": {
+                        "mode": "NO_SCALING",
+                        "maxFleetSize": 1,
+                    },
+                    "workerRequirements": {
+                        "vCpuCount": {"min": 1},
+                        "memoryMiB": {"min": 1024},
+                        "osFamily": "linux",
+                        "cpuArchitectureType": "x86_64",
+                    },
+                },
+            },
+            role_arn=bootstrap_resources.worker_role_arn,
+        )
+        qfa = QueueFleetAssociation.create(
+            client=deadline_client,
+            farm=farm,
+            queue=queue,
+            fleet=fleet,
+        )
-        )
-        return start_command_response
-
-    configuration_result = configure_worker_agent_func()
-
-    assert configuration_result["ResponseCode"] == 0
-
-    run_worker = start_worker_agent_func()
-
-    assert run_worker["ResponseCode"] == 0
+        yield DeadlineResources(
+            farm_id=farm.id,
+            queue_id=queue.id,
+            fleet_id=fleet.id,
+            job_attachments_bucket=bootstrap_resources.job_attachments.bucket_name
+            if bootstrap_resources.job_attachments
+            else None,
+        )
-    yield run_worker
+        qfa.delete(client=deadline_client)
+        fleet.delete(client=deadline_client)
+        queue.delete(client=deadline_client)
+        farm.delete(client=deadline_client)

 @pytest.fixture(scope="session")
-def send_ssm_command(ssm_client: botocore.client.BaseClient) -> Callable:
-    def send_ssm_command_func(instance_id: str, command: str) -> Dict:
-        """Helper function to send single commands via SSM to a shell on a launched EC2 instance. Once the command has fully
-        finished the result of the invocation is returned. 
- """ - ssm_waiter = ssm_client.get_waiter("command_executed") - - # To successfully send an SSM Command to an instance the instance must: - # 1) Be in RUNNING state; - # 2) Have the AWS Systems Manager (SSM) Agent running; and - # 3) Have had enough time for the SSM Agent to connect to System's Manager - # - # If we send an SSM command then we will get an InvalidInstanceId error - # if the instance isn't in that state. - NUM_RETRIES = 10 - SLEEP_INTERVAL_S = 5 - for i in range(0, NUM_RETRIES): - try: - send_command_response = ssm_client.send_command( - InstanceIds=[instance_id], - DocumentName="AWS-RunShellScript", - Parameters={"commands": [command]}, - ) - # Successfully sent. Bail out of the loop. - break - except botocore.exceptions.ClientError as error: - error_code = error.response["Error"]["Code"] - if error_code == "InvalidInstanceId" and i < NUM_RETRIES - 1: - time.sleep(SLEEP_INTERVAL_S) - continue - raise - - command_id = send_command_response["Command"]["CommandId"] - - ssm_waiter.wait(InstanceId=instance_id, CommandId=command_id) - ssm_command_result = ssm_client.get_command_invocation( - InstanceId=instance_id, CommandId=command_id +def worker( + request: pytest.FixtureRequest, + deadline_client: DeadlineClient, + deadline_resources: DeadlineResources, + codeartifact: CodeArtifactRepositoryInfo, + service_model_s3_object: S3Object | None, +) -> Generator[DeadlineWorker, None, None]: + """ + Gets a DeadlineWorker for use in tests. + + Environment Variables: + SUBNET_ID: The subnet ID to deploy the EC2 worker into. + This is required for EC2 workers. Does not apply if USE_DOCKER_WORKER is true. + SECURITY_GROUP_ID: The security group ID to deploy the EC2 worker into. + This is required for EC2 workers. Does not apply if USE_DOCKER_WORKER is true. + AMI_ID: The AMI ID to use for the Worker agent. + Defaults to the latest AL2023 AMI. + Does not apply if USE_DOCKER_WORKER is true. + WORKER_REGION: The AWS region to configure the worker for + Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2 + WORKER_POSIX_USER: The POSIX user to configure the worker for + Defaults to "deadline-worker" + WORKER_POSIX_SHARED_GROUP: The shared POSIX group to configure the worker user and job user with + Defaults to "shared-group" + WORKER_AGENT_WHL_PATH: Path to the Worker agent wheel file to use. + WORKER_AGENT_REQUIREMENT_SPECIFIER: PEP 508 requirement specifier for the Worker agent package. + If WORKER_AGENT_WHL_PATH is provided, this option is ignored. + LOCAL_MODEL_PATH: Path to a local Deadline model file to use for API calls. + If DEADLINE_SERVICE_MODEL_S3_URI was provided, this option is ignored. + USE_DOCKER_WORKER: If set to "true", this fixture will create a Worker that runs in a local Docker container instead of an EC2 instance. + + Returns: + DeadlineWorker: Instance of the DeadlineWorker class that can be used to interact with the Worker. 
+ """ + file_mappings: list[tuple[str, str]] = [] + + # Prepare the Worker agent Python package + worker_agent_whl_path = os.getenv("WORKER_AGENT_WHL_PATH") + if worker_agent_whl_path: + LOG.info(f"Using Worker agent whl file: {worker_agent_whl_path}") + resolved_whl_paths = glob.glob(worker_agent_whl_path) + assert ( + len(resolved_whl_paths) == 1 + ), f"Expected exactly one Worker agent whl path, but got {resolved_whl_paths} (from pattern {worker_agent_whl_path})" + resolved_whl_path = resolved_whl_paths[0] + + dest_path = posixpath.join("/tmp", os.path.basename(resolved_whl_path)) + file_mappings = [(resolved_whl_path, dest_path)] + + LOG.info(f"The whl file will be copied to {dest_path} on the Worker environment") + worker_agent_requirement_specifier = dest_path + else: + worker_agent_requirement_specifier = os.getenv( + "WORKER_AGENT_REQUIREMENT_SPECIFIER", + "deadline-cloud-worker-agent", + ) + LOG.info(f"Using Worker agent package {worker_agent_requirement_specifier}") + + # Prepare the service model + service_model: ServiceModel + if service_model_s3_object: + LOG.info(f"Using Deadline model from S3: {service_model_s3_object.uri}") + service_model = ServiceModel.from_s3( + local_filename=posixpath.join("/tmp", "deadline-cloud-service-model.json"), + object=service_model_s3_object, + service_name="deadline", + ) + else: + local_model_path = os.getenv("LOCAL_MODEL_PATH") + if local_model_path: + LOG.info( + f"Using Deadline model from local path provided via env var: {local_model_path}" + ) + else: + local_model_path = _find_latest_service_model_file("deadline") + LOG.info(f"Using Deadline model installed at: {local_model_path}") + dst_path = posixpath.join("/tmp", "deadline-cloud-service-model.json") + service_model = ServiceModel.from_local_file( + local_file_path=dst_path, service_name="deadline" + ) + file_mappings.append((local_model_path, dst_path)) + + configuration = DeadlineWorkerConfiguration( + farm_id=deadline_resources.farm.id, + fleet_id=deadline_resources.fleet.id, + region=os.getenv("WORKER_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")), + user=os.getenv("WORKER_POSIX_USER", "deadline-worker"), + group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"), + allow_shutdown=True, + worker_agent_install=PipInstall( + requirement_specifiers=[worker_agent_requirement_specifier], + codeartifact=codeartifact, + ), + service_model=service_model, + file_mappings=file_mappings or None, + ) + + worker: DeadlineWorker + if os.environ.get("USE_DOCKER_WORKER", False): + LOG.info("Creating Docker worker") + worker = DockerContainerWorker( + configuration=configuration, + ) + else: + LOG.info("Creating EC2 worker") + ami_id = os.getenv("AMI_ID") + subnet_id = os.getenv("SUBNET_ID") + security_group_id = os.getenv("SECURITY_GROUP_ID") + assert subnet_id, "SUBNET_ID is required when deploying an EC2 worker" + assert security_group_id, "SECURITY_GROUP_ID is required when deploying an EC2 worker" + + bootstrap_resources: BootstrapResources = request.getfixturevalue("bootstrap_resources") + assert ( + bootstrap_resources.worker_instance_profile_name + ), "Worker instance profile is required when deploying an EC2 worker" + + ec2_client = boto3.client("ec2") + s3_client = boto3.client("s3") + ssm_client = boto3.client("ssm") + + worker = EC2InstanceWorker( + ec2_client=ec2_client, + deadline_client=deadline_client, + s3_client=s3_client, + bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name, + ssm_client=ssm_client, + override_ami_id=ami_id, + subnet_id=subnet_id, + 
security_group_id=security_group_id, + instance_profile_name=bootstrap_resources.worker_instance_profile_name, + configuration=configuration, ) - return ssm_command_result + def stop_worker(): + try: + worker.stop() + except Exception as e: + LOG.exception(f"Error while stopping worker: {e}") + LOG.error( + "Failed to stop worker. Resources may be left over that need to be cleaned up manually." + ) + raise - return send_ssm_command_func + try: + worker.start() + except Exception as e: + LOG.exception(f"Failed to start worker: {e}") + if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() != "true": + LOG.info("Stopping worker because it failed to start") + stop_worker() + raise + yield worker -@pytest.fixture(scope="session") -def job_attachment_manager_fixture(stage: str, account_id: str): - job_attachment_manager = JobAttachmentManager(stage, account_id) - yield job_attachment_manager + stop_worker() @pytest.fixture(scope="session") -def deploy_job_attachment_resources(job_attachment_manager_fixture: JobAttachmentManager): - job_attachment_manager_fixture.deploy_resources() - yield job_attachment_manager_fixture - job_attachment_manager_fixture.cleanup_resources() - - -class _InstanceLauncher: - ami_id: str - subnet_id: str - security_group_id: str - instance_profile_name: str - instance_id: str - ec2_client: botocore.client.BaseClient - - def __init__( - self, - ec2_client: botocore.client.BaseClient, - ami_id: str, - subnet_id: str, - security_group_id: str, - instance_profile_name: str, - ) -> None: - self.ec2_client = ec2_client - self.ami_id = ami_id - self.subnet_id = subnet_id - self.security_group_id = security_group_id - self.instance_profile_name = instance_profile_name - - def __enter__(self) -> str: - instance_running_waiter = self.ec2_client.get_waiter("instance_status_ok") - - run_instance_response = self.ec2_client.run_instances( - MinCount=1, - MaxCount=1, - ImageId=self.ami_id, - InstanceType="t3.micro", - IamInstanceProfile={"Name": self.instance_profile_name}, - SubnetId=self.subnet_id, - SecurityGroupIds=[self.security_group_id], - MetadataOptions={"HttpTokens": "required", "HttpEndpoint": "enabled"}, - TagSpecifications=[ - { - "ResourceType": "instance", - "Tags": [{"Key": "InstanceIdentification", "Value": f"TestScaffolding{STAGE}"}], - } - ], - ) - - self.instance_id = run_instance_response["Instances"][0]["InstanceId"] - - instance_running_waiter.wait(InstanceIds=[self.instance_id]) - return self.instance_id - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - self.ec2_client.terminate_instances(InstanceIds=[self.instance_id]) +def deploy_job_attachment_resources() -> Generator[JobAttachmentManager, None, None]: + """ + Deploys Job Attachments resources for integration tests + + Environment Variables: + SERVICE_ACCOUNT_ID: The account ID the resources will be deployed to + STAGE: The stage these resources are being deployed to + Defaults to "dev" + + Returns: + JobAttachmentManager: Class to manage Job Attachments resources + """ + manager = JobAttachmentManager( + account_id=os.environ["SERVICE_ACCOUNT_ID"], + stage=os.getenv("STAGE", "dev"), + ) + manager.deploy_resources() + yield manager + manager.cleanup_resources() + + +def _find_latest_service_model_file(service_name: str) -> str: + loader = botocore.loaders.Loader(include_default_search_paths=True) + full_name = os.path.join( + service_name, loader.determine_latest_version(service_name, 
"service-2"), "service-2" + ) + _, service_model_path = loader.load_data_with_path(full_name) + return f"{service_model_path}.json" diff --git a/src/deadline_test_scaffolding/job_attachment_manager.py b/src/deadline_test_scaffolding/job_attachment_manager.py index b0757c4..758b3c8 100644 --- a/src/deadline_test_scaffolding/job_attachment_manager.py +++ b/src/deadline_test_scaffolding/job_attachment_manager.py @@ -1,11 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from __future__ import annotations -import pathlib +from typing import Any import boto3 +from botocore.client import BaseClient from botocore.exceptions import ClientError, WaiterError -from deadline_test_scaffolding.deadline_manager import DeadlineManager +from .cloudformation import JobAttachmentsBootstrapStack +from .deadline.client import DeadlineClient +from .deadline import ( + Farm, + Queue, +) class JobAttachmentManager: @@ -13,91 +20,46 @@ class JobAttachmentManager: Responsible for setting up and tearing down job attachment test resources """ - RESOURCE_CF_TEMPLATE_LOCATION = pathlib.Path( - pathlib.Path(__file__).parent / "cf_templates" / "job_attachments.yaml" - ) + cfn_client: BaseClient + deadline_client: DeadlineClient + bucket: Any + farm: Farm | None + queue: Queue | None + stack: JobAttachmentsBootstrapStack def __init__(self, stage: str, account_id: str): - cloudformation = boto3.resource("cloudformation") s3 = boto3.resource("s3") - self.stack = cloudformation.Stack("JobAttachmentIntegTest") - self.deadline_manager = DeadlineManager(should_add_deadline_models=True) + self.cfn_client = boto3.client("cloudformation") + self.deadline_client = DeadlineClient(boto3.client("deadline")) + self.bucket = s3.Bucket(f"job-attachment-integ-test-{stage.lower()}-{account_id}") + self.stack = JobAttachmentsBootstrapStack( + name="JobAttachmentIntegTest", + bucket_name=self.bucket.name, + ) + self.farm = None + self.queue = None def deploy_resources(self): """ Deploy all of the resources needed for job attachment integration tests. """ try: - self.deadline_manager.create_kms_key() - self.deadline_manager.create_farm("job_attachments_test_farm") - self.deadline_manager.create_queue("job_attachments_test_queue") - self.deploy_stack() + self.farm = Farm.create( + client=self.deadline_client, + display_name="job_attachments_test_farm", + ) + self.queue = Queue.create( + client=self.deadline_client, + display_name="job_attachments_test_queue", + farm=self.farm, + ) + self.stack.deploy(cfn_client=self.cfn_client) except (ClientError, WaiterError): # If anything goes wrong, rollback self.cleanup_resources() raise - def _create_stack(self, template_body: str): - try: - # The stack resource doesn't have an action for creating the stack, - # only updating it. So we need to go through the client. - self.stack.meta.client.create_stack( - StackName=self.stack.name, - TemplateBody=template_body, - OnFailure="DELETE", - EnableTerminationProtection=False, - Parameters=[ - { - "ParameterKey": "BucketName", - "ParameterValue": self.bucket.name, - }, - ], - ) - except ClientError as e: - # Sometimes the cloudformation create stack waiter will release even if if the stack - # isn't in create_complete. So we have to catch that here and move on. 
- if e.response["Error"]["Message"] != f"Stack [{self.stack.name}] already exists": - raise - - waiter = self.stack.meta.client.get_waiter("stack_create_complete") - waiter.wait( - StackName=self.stack.name, - ) - - def deploy_stack(self): - """ - Deploy the job attachment test stack to the test account. If the stack already exists then - update it, if the stack doesn't exist then create it. - - Keep the stack around between tests to reduce further test times. - """ - with open(self.RESOURCE_CF_TEMPLATE_LOCATION) as f: - template_body = f.read() - - try: - self.stack.update( - TemplateBody=template_body, - Parameters=[ - { - "ParameterKey": "BucketName", - "ParameterValue": self.bucket.name, - }, - ], - ) - waiter = self.stack.meta.client.get_waiter("stack_update_complete") - waiter.wait(StackName=self.stack.name) - except ClientError as e: - if ( - "is in CREATE_IN_PROGRESS state and can not be updated." - in e.response["Error"]["Message"] - ): - waiter = self.stack.meta.client.get_waiter("stack_create_complete") - waiter.wait(StackName=self.stack.name) - - elif e.response["Error"]["Message"] != "No updates are to be performed.": - self._create_stack(template_body) - def empty_bucket(self): """ Empty the bucket between session runs @@ -112,11 +74,9 @@ def cleanup_resources(self): """ Cleanup all of the resources that the test used, except for the stack. """ - self.deadline_manager.delete_additional_queues() - if self.deadline_manager.queue_id: - self.deadline_manager.delete_queue() - if self.deadline_manager.farm_id: - self.deadline_manager.delete_farm() - if self.deadline_manager.kms_key_metadata: - self.deadline_manager.delete_kms_key() self.empty_bucket() + self.stack.destroy(cfn_client=self.cfn_client) + if self.queue: + self.queue.delete(client=self.deadline_client) + if self.farm: + self.farm.delete(client=self.deadline_client) diff --git a/src/deadline_test_scaffolding/models.py b/src/deadline_test_scaffolding/models.py new file mode 100644 index 0000000..3b9e296 --- /dev/null +++ b/src/deadline_test_scaffolding/models.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + + +@dataclass(frozen=True) +class JobAttachmentSettings: + bucket_name: str + root_prefix: str + + def as_queue_settings(self) -> dict: + return { + "s3BucketName": self.bucket_name, + "rootPrefix": self.root_prefix, + } + + +@dataclass(frozen=True) +class CodeArtifactRepositoryInfo: + region: str + domain: str + domain_owner: str + repository: str + + @property + def domain_arn(self) -> str: + return f"arn:aws:codeartifact:{self.region}:{self.domain_owner}:domain/{self.domain}" + + @property + def repository_arn(self) -> str: + return f"arn:aws:codeartifact:{self.region}:{self.domain_owner}:repository/{self.domain}/{self.repository}" + + +@dataclass(frozen=True) +class S3Object: + bucket: str + key: str + + @staticmethod + def from_uri(uri: str) -> S3Object: + match = re.match(r"s3://(.+?)/(.+)", uri) + assert isinstance(match, re.Match), f"Cannot retrieve S3 bucket and key from URI: {uri}" + bucket, key = match.groups() + return S3Object( + bucket=bucket, + key=key, + ) + + @property + def arn(self) -> str: + return f"arn:aws:s3:::{self.bucket}/{self.key}" + + @property + def uri(self) -> str: + return f"s3://{self.bucket}/{self.key}" + + +@dataclass(frozen=True) +class ServiceModel: + install_command: str + + @staticmethod + def from_s3( + *, + object: S3Object, + local_filename: str, + service_name: str | None = None, + ) -> ServiceModel: + cmd = " && 
".join( + [ + f"aws s3 cp {object.uri} {local_filename}", + # fmt: off + " ".join( + [ + "aws", + "configure", + "add-model", + "--service-model", + f"file://{local_filename}", + *(["--service-name", service_name] if service_name else []), + ] + ), + # fmt: on + ] + ) + return ServiceModel(install_command=cmd) + + @staticmethod + def from_local_file(*, local_file_path: str, service_name: str | None = None) -> ServiceModel: + cmd = " ".join( + [ + "aws", + "configure", + "add-model", + "--service-model", + f"file://{local_file_path}", + *(["--service-name", service_name] if service_name else []), + ] + ) + return ServiceModel(cmd) diff --git a/src/deadline_test_scaffolding/pytest_hooks.py b/src/deadline_test_scaffolding/pytest_hooks.py new file mode 100644 index 0000000..00f88bd --- /dev/null +++ b/src/deadline_test_scaffolding/pytest_hooks.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +import logging as _logging +from typing import Optional as _Optional + +import pytest as _pytest + +_root_logger = _logging.getLogger() +_log_filters: dict[str, _logging.Filter] = {} + + +class _PytestIdLoggerFilter(_logging.Filter): + """Filter that prepends pytest IDs to logs""" + + def __init__(self, test_id: str) -> None: + self.test_id = test_id + + def filter(self, record) -> bool: + record.msg = f"[{self.test_id}] {record.msg}" + return True + + +def pytest_sessionstart(session: _pytest.Session): + # Base logging configuration + formatter = _logging.Formatter("[%(asctime)s] %(message)s") + for handler in _root_logger.handlers: + handler.setFormatter(formatter) + + +def pytest_runtest_logstart(nodeid: str, location: tuple[str, _Optional[int], str]): + # Apply test ID log filter + log_filter = _PytestIdLoggerFilter(nodeid) + for handler in _root_logger.handlers: + handler.addFilter(log_filter) + _log_filters[nodeid] = log_filter + + +@_pytest.hookimpl(wrapper=True) +def pytest_runtest_teardown(item: _pytest.Item, nextitem: _Optional[_pytest.Item]): + # Remove test ID log filter + log_filter = _log_filters.pop(item.nodeid, None) + if log_filter: + for handler in _root_logger.handlers: + handler.removeFilter(log_filter) + + yield diff --git a/src/deadline_test_scaffolding/util.py b/src/deadline_test_scaffolding/util.py new file mode 100644 index 0000000..b417d54 --- /dev/null +++ b/src/deadline_test_scaffolding/util.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import botocore.exceptions +import logging +from time import sleep +from typing import Any, Callable + +LOG = logging.getLogger(__name__) + + +def wait_for( + *, + description: str, + predicate: Callable[[], bool], + interval_s: float, + max_retries: int | None = None, +) -> None: + if max_retries is not None: + assert max_retries > 0, "max_retries must be a positive integer" + assert interval_s > 0, "interval_s must be a positive number" + + LOG.info(f"Waiting for {description}") + retry_count = 0 + while not predicate(): + if max_retries and retry_count >= max_retries: + raise TimeoutError(f"Timed out waiting for {description}") + + LOG.info(f"Retrying in {interval_s}s...") + retry_count += 1 + sleep(interval_s) + + +def call_api(*, description: str, fn: Callable[[], Any]) -> Any: + LOG.info(f"About to call API ({description})") + try: + response = fn() + except botocore.exceptions.ClientError as e: + LOG.error(f"API call failed ({description})") + LOG.exception(f"The following exception was raised: {e}") + raise + else: + LOG.info(f"API call succeeded ({description})") + return response + 
+ +def clean_kwargs(kwargs: dict) -> dict: + """Removes None from kwargs dicts""" + return {k: v for k, v in kwargs.items() if v is not None} diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/cloudformation/__init__.py b/test/unit/cloudformation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/cloudformation/test_cfn.py b/test/unit/cloudformation/test_cfn.py new file mode 100644 index 0000000..01f6d9d --- /dev/null +++ b/test/unit/cloudformation/test_cfn.py @@ -0,0 +1,288 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from typing import Generator +from unittest.mock import MagicMock + +import boto3 +import pytest +from botocore.client import BaseClient +from botocore.exceptions import ClientError +from moto import mock_cloudformation + +from deadline_test_scaffolding.cloudformation.cfn import ( + CfnResource, + CfnStack, +) + + +class TestCfnStack: + @pytest.fixture(autouse=True) + def config_caplog(self, caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("INFO") + + @pytest.fixture + def stack(self) -> CfnStack: + return CfnStack( + name="TestStack", + description="TestStack", + ) + + @pytest.fixture(autouse=True) + def resource(self, stack: CfnStack) -> CfnResource: + # Just use a bucket since moto actually "deploys" the stack, and: + # - Custom resources actually need a backing lambda + # - Unknown resource types are not supported (i.e. anything not Custom:: or a supported AWS:: type) + return CfnResource(stack, "AWS::S3::Bucket", "Bucket", {}) + + @pytest.fixture + def cfn_client(self) -> Generator[BaseClient, None, None]: + with mock_cloudformation(): + yield boto3.client("cloudformation") + + class TestDeploy: + def test_deploys_stack( + self, + stack: CfnStack, + cfn_client: BaseClient, + resource: CfnResource, + caplog: pytest.LogCaptureFixture, + ) -> None: + # WHEN + stack.deploy(cfn_client=cfn_client) + + # THEN + assert f"Stack {stack.name} does not exist yet. Creating new stack." 
in caplog.text + stacks = cfn_client.describe_stacks(StackName=stack.name)["Stacks"] + assert len(stacks) == 1 + assert stacks[0]["StackName"] == stack.name + resources = cfn_client.describe_stack_resources(StackName=stack.name)["StackResources"] + assert len(resources) == 1 + assert resources[0]["LogicalResourceId"] == resource.logical_name + + def test_updates_stack( + self, + stack: CfnStack, + resource: CfnResource, + cfn_client: BaseClient, + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + stack.deploy(cfn_client=cfn_client) + resource2 = CfnResource(stack, "AWS::S3::Bucket", "Bucket2", {}) + + # WHEN + stack.deploy(cfn_client=cfn_client) + + # THEN + assert "Stack update complete" in caplog.text + resources = cfn_client.describe_stack_resources(StackName=stack.name)["StackResources"] + assert len(resources) == 2 + resources_logical_ids = [r["LogicalResourceId"] for r in resources] + assert resource.logical_name in resources_logical_ids + assert resource2.logical_name in resources_logical_ids + + def test_does_not_raise_when_stack_is_up_to_date( + self, + stack: CfnStack, + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + # moto doesn't raise a ValidationError like CloudFormation does when calling UpdateStack + # when there are no updates to be performed, so setup a mock instead + mock_client = MagicMock() + mock_client.update_stack.side_effect = ClientError( + {"Error": {"Message": "No updates are to be performed."}}, + "UpdateStack", + ) + + try: + # WHEN + stack.deploy(cfn_client=mock_client) + except Exception as e: + pytest.fail(f"Stack.deploy() raised an error when it shouldn't have: {e}") + else: + # THEN + assert "Stack is already up to date" in caplog.text + + @pytest.mark.parametrize( + "error", + [ + ClientError({"Error": {"Message": "test"}}, None), + Exception(), + ], + ) + def test_raises_other_errors( + self, + error: Exception, + stack: CfnStack, + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + mock_client = MagicMock() + mock_client.update_stack.side_effect = error + + with pytest.raises(type(error)) as raised_err: + # WHEN + stack.deploy(cfn_client=mock_client) + + # THEN + assert raised_err.value is error + if isinstance(error, ClientError): + assert ( + f"Unexpected error when attempting to update stack {stack.name}: " + in caplog.text + ) + + def test_destroy( + self, + stack: CfnStack, + cfn_client: BaseClient, + ) -> None: + # GIVEN + stack.deploy(cfn_client=cfn_client) + stacks = cfn_client.describe_stacks(StackName=stack.name)["Stacks"] + assert len(stacks) == 1 + assert stacks[0]["StackStatus"] == "CREATE_COMPLETE" + + # WHEN + stack.destroy(cfn_client=cfn_client) + + # THEN + # DescribeStacks API required Stack ID for deleted stacks, so just use ListStacks + stacks = cfn_client.list_stacks()["StackSummaries"] + assert len(stacks) == 1 + assert stacks[0]["StackStatus"] == "DELETE_COMPLETE" + + def test_template(self) -> None: + # GIVEN + stack = CfnStack(name="TheStack", description="TheDescription") + resources = [ + CfnResource(stack, "AWS::S3::Bucket", "BucketA", {"A": "a"}), + CfnResource(stack, "AWS::S3::Bucket", "BucketB", {"B": "b"}), + CfnResource(stack, "AWS::S3::Bucket", "BucketC", {"C": "c"}), + ] + + # WHEN + template = stack.template + + # THEN + assert template["Description"] == stack.description + for resource in resources: + assert resource.logical_name in template["Resources"] + assert resource.template == template["Resources"][resource.logical_name] + + +class TestCfnResource: + @pytest.fixture + def 
stack(self) -> CfnStack: + return CfnStack(name="TestStack") + + class TestTemplate: + @pytest.fixture + def resource_type(self) -> str: + return "Test::Resource::Type" + + @pytest.fixture + def resource_props(self) -> dict: + return { + "PropertyA": "ValueA", + "PropertyB": { + "A": "a", + "B": "b", + }, + } + + @pytest.fixture + def resource( + self, + stack: CfnStack, + resource_type: str, + resource_props: dict, + ) -> CfnResource: + # GIVEN + return CfnResource(stack, resource_type, "TestResource", resource_props) + + def test_defaults( + self, resource: CfnResource, resource_type: str, resource_props: dict + ) -> None: + # THEN + assert resource.template == { + "Type": resource_type, + "Properties": resource_props, + } + + def test_applies_update_replace_policy(self, resource: CfnResource) -> None: + # GIVEN + resource.update_replace_policy = "Retain" + + # THEN + assert resource.template["UpdateReplacePolicy"] == "Retain" + + def test_applies_deletion_policy(self, resource: CfnResource) -> None: + # GIVEN + resource.deletion_policy = "Retain" + + # THEN + assert resource.template["DeletionPolicy"] == "Retain" + + class TestPhysicalName: + def test_gets_physical_name(self, stack: CfnStack) -> None: + # GIVEN + class TestResource(CfnResource): + _physical_name_prop = "TestName" + + resource = TestResource( + stack, + "Test::Resource::Type", + "TestResource", + { + "TestName": "PhysicalName", + }, + ) + + # THEN + assert resource.physical_name == "PhysicalName" + + def test_raises_when_no_physical_name(self, stack: CfnStack) -> None: + # GIVEN + resource = CfnResource( + stack, + "Test::Resource::Type", + "TestResource", + { + "TestName": "PhysicalName", + }, + ) + + with pytest.raises(ValueError) as raised_err: + # WHEN + resource.physical_name + + # THEN + assert ( + str(raised_err.value) + == "Resource type Test::Resource::Type does not have a physical name" + ) + + def test_raises_when_physical_name_not_in_properties(self, stack: CfnStack) -> None: + # GIVEN + class TestResource(CfnResource): + _physical_name_prop = "TestName" + + resource = TestResource(stack, "Test::Resource::Type", "TestResource", {}) + + with pytest.raises(ValueError) as raised_err: + # WHEN + resource.physical_name + + # THEN + assert ( + str(raised_err.value) + == "Physical name was not specified for this resource (TestResource)" + ) + + def test_init_adds_to_stack(self, stack: CfnStack) -> None: + # WHEN + resource = CfnResource(stack, "Test::Resource::Type", "TestResource", {}) + + # THEN + assert resource in stack._resources diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 77dc4cc..0fb0358 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -1,26 +1,19 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
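+# NOTE: the autouse boto_config fixture below injects obviously-fake credentials
+# (the values recommended by moto's documentation) so that unit tests can never
+# reach a real AWS account.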
from typing import Generator -from unittest import mock from unittest.mock import patch import pytest -from deadline_test_scaffolding import DeadlineManager - -@pytest.fixture() -def mock_get_deadline_models(): - with mock.patch.object(DeadlineManager, "get_deadline_models") as mocked_get_deadline_models: - yield mocked_get_deadline_models - - -@pytest.fixture(scope="function") -def boto_config() -> Generator[None, None, None]: - updated_environment = { - "AWS_ACCESS_KEY_ID": "ACCESSKEY", - "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - "AWS_DEFAULT_REGION": "us-west-2", +@pytest.fixture(scope="function", autouse=True) +def boto_config() -> Generator[dict[str, str], None, None]: + config = { + "AWS_ACCESS_KEY_ID": "testing", + "AWS_SECRET_ACCESS_KEY": "testing", + "AWS_SECURITY_TOKEN": "testing", + "AWS_SESSION_TOKEN": "testing", + "AWS_DEFAULT_REGION": "us-east-1", } - with patch.dict("os.environ", updated_environment): - yield + with patch.dict("os.environ", config): + yield config diff --git a/test/unit/deadline/__init__.py b/test/unit/deadline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/test_deadline_shim.py b/test/unit/deadline/test_client.py similarity index 95% rename from test/unit/test_deadline_shim.py rename to test/unit/deadline/test_client.py index 5655328..b164bac 100644 --- a/test/unit/test_deadline_shim.py +++ b/test/unit/deadline/test_client.py @@ -2,28 +2,30 @@ import pytest from unittest.mock import MagicMock, patch from deadline_test_scaffolding import DeadlineClient -from shared_constants import MOCK_FARM_NAME, MOCK_FLEET_NAME, MOCK_QUEUE_NAME +MOCK_FARM_NAME = "test-farm" +MOCK_FLEET_NAME = "test-fleet" +MOCK_QUEUE_NAME = "test-queue" -class FakeClient: - def fake_deadline_client_has_this(self) -> str: - return "from fake client" - def but_not_this(self) -> str: - return "from fake client" - - -class FakeDeadlineClient(DeadlineClient): - def fake_deadline_client_has_this(self) -> str: - return "from fake deadline client" - - -class TestDeadlineShim: +class TestDeadlineClient: def test_deadline_client_pass_through(self) -> None: """ Confirm that DeadlineClient passes through unknown methods to the underlying client but just executes known methods. 
""" + + class FakeClient: + def fake_deadline_client_has_this(self) -> str: + return "from fake client" + + def but_not_this(self) -> str: + return "from fake client" + + class FakeDeadlineClient(DeadlineClient): + def fake_deadline_client_has_this(self) -> str: + return "from fake deadline client" + fake_client = FakeClient() deadline_client = FakeDeadlineClient(fake_client) diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py new file mode 100644 index 0000000..af56ebc --- /dev/null +++ b/test/unit/deadline/test_resources.py @@ -0,0 +1,640 @@ +import datetime +import json +from dataclasses import replace +from typing import Any, Generator, cast +from unittest.mock import MagicMock, call, patch + +import pytest + +from deadline_test_scaffolding import ( + Farm, + Queue, + Fleet, + QueueFleetAssociation, + Job, + JobAttachmentSettings, + TaskStatus, +) +from deadline_test_scaffolding.deadline import resources as mod + + +@pytest.fixture(autouse=True) +def wait_for_shim() -> Generator[None, None, None]: + import sys + from deadline_test_scaffolding.util import wait_for + + # Force the wait_for to have a short interval for unit tests + def wait_for_shim(*args, **kwargs): + kwargs.pop("interval_s", None) + kwargs.pop("max_retries", None) + wait_for(*args, **kwargs, interval_s=sys.float_info.epsilon, max_retries=None) + + with patch.object(mod, "wait_for", wait_for_shim): + yield + + +@pytest.fixture +def farm() -> Farm: + return Farm(id="farm-123") + + +@pytest.fixture +def queue(farm: Farm) -> Queue: + return Queue(id="queue-123", farm=farm) + + +@pytest.fixture +def fleet(farm: Farm) -> Fleet: + return Fleet(id="fleet-123", farm=farm) + + +@pytest.fixture +def qfa(farm: Farm, queue: Queue, fleet: Fleet) -> QueueFleetAssociation: + return QueueFleetAssociation( + farm=farm, + queue=queue, + fleet=fleet, + ) + + +class TestFarm: + def test_create(self) -> None: + # GIVEN + display_name = "test-farm" + farm_id = "farm-123" + mock_client = MagicMock() + mock_client.create_farm.return_value = {"farmId": farm_id} + + # WHEN + result = Farm.create(client=mock_client, display_name=display_name) + + # THEN + assert result.id == farm_id + mock_client.create_farm.assert_called_once_with( + displayName=display_name, + ) + + def test_delete(self, farm: Farm) -> None: + # GIVEN + mock_client = MagicMock() + + # WHEN + farm.delete(client=mock_client) + + # THEN + mock_client.delete_farm.assert_called_once_with(farmId=farm.id) + + +class TestQueue: + def test_create(self, farm: Farm) -> None: + # GIVEN + display_name = "test-queue" + queue_id = "queue-123" + role_arn = "arn:aws:iam::123456789123:role/TestRole" + job_attachments = JobAttachmentSettings(bucket_name="bucket", root_prefix="root") + mock_client = MagicMock() + mock_client.create_queue.return_value = {"queueId": queue_id} + + # WHEN + result = Queue.create( + client=mock_client, + display_name=display_name, + farm=farm, + role_arn=role_arn, + job_attachments=job_attachments, + ) + + # THEN + assert result.id == queue_id + mock_client.create_queue.assert_called_once_with( + displayName=display_name, + farmId=farm.id, + roleArn=role_arn, + jobAttachmentSettings=job_attachments.as_queue_settings(), + ) + + def test_delete(self, queue: Queue) -> None: + # GIVEN + mock_client = MagicMock() + + # WHEN + queue.delete(client=mock_client) + + # THEN + mock_client.delete_queue.assert_called_once_with(queueId=queue.id, farmId=queue.farm.id) + + +class TestFleet: + def test_create(self, farm: Farm) -> None: + # GIVEN + 
display_name = "test-fleet" + fleet_id = "fleet-123" + configuration: dict = {} + role_arn = "arn:aws:iam::123456789123:role/TestRole" + mock_client = MagicMock() + mock_client.create_fleet.return_value = {"fleetId": fleet_id} + + # WHEN + with patch.object(Fleet, "wait_for_desired_status") as mock_wait_for_desired_status: + result = Fleet.create( + client=mock_client, + display_name=display_name, + farm=farm, + configuration=configuration, + role_arn=role_arn, + ) + + # THEN + assert result.id == fleet_id + mock_client.create_fleet.assert_called_once_with( + farmId=farm.id, + displayName=display_name, + roleArn=role_arn, + configuration=configuration, + ) + mock_wait_for_desired_status.assert_called_once_with( + client=mock_client, + desired_status="ACTIVE", + allowed_statuses=set(["CREATE_IN_PROGRESS"]), + ) + + def test_delete(self, fleet: Fleet) -> None: + # GIVEN + mock_client = MagicMock() + + # WHEN + fleet.delete(client=mock_client) + + # THEN + mock_client.delete_fleet.assert_called_once_with(fleetId=fleet.id, farmId=fleet.farm.id) + + class TestWaitForDesiredStatus: + def test_waits(self, fleet: Fleet) -> None: + # GIVEN + desired_status = "ACTIVE" + allowed_statuses = set(["CREATE_IN_PROGRESS"]) + mock_client = MagicMock() + mock_client.get_fleet.side_effect = [ + {"status": "CREATE_IN_PROGRESS"}, + {"status": "ACTIVE"}, + ] + + # WHEN + fleet.wait_for_desired_status( + client=mock_client, + desired_status=desired_status, + allowed_statuses=allowed_statuses, + ) + + # THEN + mock_client.get_fleet.assert_has_calls( + [call(fleetId=fleet.id, farmId=fleet.farm.id)] * 2 + ) + + def test_raises_when_nonvalid_status_is_reached(self, fleet: Fleet) -> None: + # GIVEN + desired_status = "ACTIVE" + allowed_statuses = set(["CREATE_IN_PROGRESS"]) + mock_client = MagicMock() + mock_client.get_fleet.side_effect = [ + {"status": "BAD"}, + ] + + with pytest.raises(ValueError) as raised_err: + # WHEN + fleet.wait_for_desired_status( + client=mock_client, + desired_status=desired_status, + allowed_statuses=allowed_statuses, + ) + + # THEN + assert ( + str(raised_err.value) + == "fleet entered a nonvalid status (BAD) while waiting for the desired status: ACTIVE" + ) + + +class TestQueueFleetAssociation: + def test_create(self, farm: Farm, queue: Queue, fleet: Fleet) -> None: + # GIVEN + mock_client = MagicMock() + + # WHEN + QueueFleetAssociation.create( + client=mock_client, + farm=farm, + queue=queue, + fleet=fleet, + ) + + # THEN + mock_client.create_queue_fleet_association.assert_called_once_with( + farmId=farm.id, + queueId=queue.id, + fleetId=fleet.id, + ) + + @pytest.mark.parametrize( + "stop_mode", + [ + "STOP_SCHEDULING_AND_CANCEL_TASKS", + "STOP_SCHEDULING_AND_FINISH_TASKS", + ], + ) + def test_delete(self, stop_mode: Any, qfa: QueueFleetAssociation) -> None: + # GIVEN + mock_client = MagicMock() + + # WHEN + with patch.object(qfa, "stop") as mock_stop: + qfa.delete(client=mock_client, stop_mode=stop_mode) + + # THEN + mock_client.delete_queue_fleet_association.assert_called_once_with( + farmId=qfa.farm.id, + queueId=qfa.queue.id, + fleetId=qfa.fleet.id, + ) + mock_stop.assert_called_once_with( + client=mock_client, + stop_mode=stop_mode, + ) + + class TestStop: + @pytest.mark.parametrize( + "stop_mode", + [ + "STOP_SCHEDULING_AND_CANCEL_TASKS", + "STOP_SCHEDULING_AND_FINISH_TASKS", + ], + ) + def test_stops(self, stop_mode: Any, qfa: QueueFleetAssociation) -> None: + # GIVEN + mock_client = MagicMock() + mock_client.get_queue_fleet_association.side_effect = [ + {"status": stop_mode}, + 
{"status": "STOPPED"}, + ] + + # WHEN + qfa.stop(client=mock_client, stop_mode=stop_mode) + + # THEN + mock_client.update_queue_fleet_association.assert_called_once_with( + farmId=qfa.farm.id, + queueId=qfa.queue.id, + fleetId=qfa.fleet.id, + status=stop_mode, + ) + mock_client.get_queue_fleet_association.assert_has_calls( + [call(farmId=qfa.farm.id, queueId=qfa.queue.id, fleetId=qfa.fleet.id)] * 2 + ) + + def test_raises_when_nonvalid_status_is_reached(self, qfa: QueueFleetAssociation) -> None: + # GIVEN + mock_client = MagicMock() + mock_client.get_queue_fleet_association.side_effect = [ + {"status": "BAD"}, + ] + + with pytest.raises(ValueError) as raised_err: + # WHEN + qfa.stop( + client=mock_client, + stop_mode="STOP_SCHEDULING_AND_CANCEL_TASKS", + ) + + # THEN + assert ( + str(raised_err.value) + == "Association entered a nonvalid status (BAD) while waiting for the desired status: STOPPED" + ) + + +class TestJob: + @staticmethod + def task_run_status_counts( + pending: int = 0, + ready: int = 0, + assigned: int = 0, + scheduled: int = 0, + interrupting: int = 0, + running: int = 0, + suspended: int = 0, + canceled: int = 0, + failed: int = 0, + succeeded: int = 0, + ) -> dict: + return { + "PENDING": pending, + "READY": ready, + "ASSIGNED": assigned, + "SCHEDULED": scheduled, + "INTERRUPTING": interrupting, + "RUNNING": running, + "SUSPENDED": suspended, + "CANCELED": canceled, + "FAILED": failed, + "SUCCEEDED": succeeded, + } + + @pytest.fixture + def job(self, farm: Farm, queue: Queue) -> Job: + return Job( + id="job-123", + farm=farm, + queue=queue, + template={}, + name="Job Name", + lifecycle_status="CREATE_COMPLETE", + lifecycle_status_message="Nice", + priority=1, + created_at=datetime.datetime.now(), + created_by="test-user", + ) + + def test_submit( + self, + farm: Farm, + queue: Queue, + ) -> None: + """ + Verifies that Job.submit creates the job, retrieves its details, and returns the expected Job object. + Note that for the GetJob call, only those that are relevant for Job creation are included. 
Full testing + of GetJob calls is covered by the test for Job.get_job_details + """ + # GIVEN + # CreateJob parameters + template = { + "specificationVersion": "2022-09-01", + "name": "Test Job", + "parameters": [ + { + "name": "Text", + "type": "STRING", + }, + ], + "steps": [ + { + "name": "Step0", + "script": { + "actions": { + "onRun": {"command": "/bin/echo", "args": [r"{{ Param.Text }}"]} + }, + }, + }, + ], + } + priority = 1 + parameters = {"Text": {"string": "Hello world"}} + target_task_run_status = "SUSPENDED" + max_failed_tasks_count = 0 + max_retries_per_task = 0 + + mock_client = MagicMock() + + # CreateJob mock + job_id = "job-123" + mock_client.create_job.return_value = {"jobId": job_id} + + # GetJob mock + task_run_status_counts = TestJob.task_run_status_counts( + **{target_task_run_status.lower(): 1} + ) + created_at = datetime.datetime.now() + mock_client.get_job.return_value = { + "jobId": job_id, + "name": "Test Job", + "lifecycleStatus": "CREATE_COMPLETE", + "lifecycleStatusMessage": "Nice", + "priority": priority, + "createdAt": created_at, + "createdBy": "test-user", + "taskRunStatus": target_task_run_status, + "taskRunStatusCounts": task_run_status_counts, + "maxFailedTasksCount": max_failed_tasks_count, + "maxRetriesPerTask": max_retries_per_task, + "parameters": parameters, + } + + # WHEN + job = Job.submit( + client=mock_client, + farm=farm, + queue=queue, + template=template, + priority=priority, + parameters=parameters, + target_task_run_status=target_task_run_status, + max_failed_tasks_count=max_failed_tasks_count, + max_retries_per_task=max_retries_per_task, + ) + + # THEN + mock_client.create_job.assert_called_once_with( + farmId=farm.id, + queueId=queue.id, + template=json.dumps(template), + templateType="JSON", + priority=priority, + parameters=parameters, + targetTaskRunStatus=target_task_run_status, + maxFailedTasksCount=max_failed_tasks_count, + maxRetriesPerTask=max_retries_per_task, + ) + mock_client.get_job.assert_called_once_with( + farmId=farm.id, + queueId=queue.id, + jobId=job_id, + ) + assert job.id == job_id + assert job.farm is farm + assert job.queue is queue + assert job.template == template + assert job.name == "Test Job" + assert job.lifecycle_status == "CREATE_COMPLETE" + assert job.lifecycle_status_message == "Nice" + assert job.priority == priority + assert job.created_at == created_at + assert job.created_by == "test-user" + assert job.task_run_status == target_task_run_status + assert job.task_run_status_counts == task_run_status_counts + assert job.max_failed_tasks_count == max_failed_tasks_count + assert job.max_retries_per_task == max_retries_per_task + assert job.parameters == parameters + + def test_get_job_details(self, farm: Farm, queue: Queue) -> None: + """ + Verifies that Job.get_job_details correctly maps the GetJob response to + kwargs that are compatible with Job.__init__ + """ + # GIVEN + now = datetime.datetime.now() + job_id = "job-123" + response = { + "jobId": job_id, + "name": "Job Name", + "lifecycleStatus": "CREATE_COMPLETE", + "lifecycleStatusMessage": "Nice", + "priority": 1, + "createdAt": now - datetime.timedelta(hours=1), + "createdBy": "User A", + "updatedAt": now - datetime.timedelta(minutes=30), + "updatedBy": "User B", + "startedAt": now - datetime.timedelta(minutes=15), + "endedAt": now - datetime.timedelta(minutes=1), + # Just need to test all fields.. 
this is not reflective of a real job state + "taskRunStatus": "RUNNING", + "targetTaskRunStatus": "SUCCEEDED", + "taskRunStatusCounts": TestJob.task_run_status_counts(running=2, succeeded=8), + "storageProfileId": "storage-profile-id-123", + "maxFailedTasksCount": 3, + "maxRetriesPerTask": 1, + "parameters": {"ParamA": {"int": "1"}}, + "attachments": { + "manifests": [ + { + "rootPath": "/root", + "osType": "linux", + }, + ], + "assetLoadingMethod": "PRELOAD", + }, + "description": "Testing", + } + mock_client = MagicMock() + mock_client.get_job.return_value = response + + # WHEN + kwargs = Job.get_job_details(client=mock_client, farm=farm, queue=queue, job_id=job_id) + + # THEN + # Verify kwargs are parsed/transformed correctly + assert kwargs["id"] == job_id + assert kwargs["name"] == response["name"] + assert kwargs["lifecycle_status"] == response["lifecycleStatus"] + assert kwargs["lifecycle_status_message"] == response["lifecycleStatusMessage"] + assert kwargs["priority"] == response["priority"] + assert kwargs["created_at"] == response["createdAt"] + assert kwargs["created_by"] == response["createdBy"] + assert kwargs["updated_at"] == response["updatedAt"] + assert kwargs["updated_by"] == response["updatedBy"] + assert kwargs["started_at"] == response["startedAt"] + assert kwargs["ended_at"] == response["endedAt"] + assert kwargs["task_run_status"] == TaskStatus[cast(str, response["taskRunStatus"])] + assert ( + kwargs["target_task_run_status"] + == TaskStatus[cast(str, response["targetTaskRunStatus"])] + ) + assert kwargs["task_run_status_counts"] == { + TaskStatus[k]: v for k, v in cast(dict, response["taskRunStatusCounts"]).items() + } + assert kwargs["storage_profile_id"] == response["storageProfileId"] + assert kwargs["max_failed_tasks_count"] == response["maxFailedTasksCount"] + assert kwargs["max_retries_per_task"] == response["maxRetriesPerTask"] + assert kwargs["parameters"] == response["parameters"] + assert kwargs["attachments"] == response["attachments"] + assert kwargs["description"] == response["description"] + + # Verify Job.__init__ accepts the kwargs + try: + Job(farm=farm, queue=queue, template={}, **kwargs) + except TypeError as e: + pytest.fail(f"Job.__init__ did not accept kwargs from Job.get_job_details: {e}") + else: + # Success + pass + + def test_refresh_job_info(self, job: Job) -> None: + # GIVEN + # Copy job object for comparison later + original_job = replace(job) + + # Mock GetJob + get_job_response = { + "jobId": job.id, + "name": job.name, + "lifecycleStatus": job.lifecycle_status, + "lifecycleStatusMessage": job.lifecycle_status_message, + "createdAt": job.created_at, + "createdBy": job.created_by, + # Change the priority + "priority": 2, + } + mock_client = MagicMock() + mock_client.get_job.return_value = get_job_response + + # WHEN + job.refresh_job_info(client=mock_client) + + # THEN + mock_client.get_job.assert_called_once() + + # Verify priority changed... 
+ assert job.priority == 2 + + # ...and everything else stayed the same + assert job.name == original_job.name + assert job.lifecycle_status == original_job.lifecycle_status + assert job.lifecycle_status_message == original_job.lifecycle_status_message + assert job.created_at == original_job.created_at + assert job.created_by == original_job.created_by + + def test_update(self, job: Job) -> None: + # GIVEN + priority = 24 + target_task_run_status = "READY" + max_failed_tasks_count = 1 + max_retries_per_task = 10 + mock_client = MagicMock() + + # WHEN + job.update( + client=mock_client, + priority=priority, + target_task_run_status=target_task_run_status, + max_failed_tasks_count=max_failed_tasks_count, + max_retries_per_task=max_retries_per_task, + ) + + # THEN + mock_client.update_job.assert_called_once_with( + farmId=job.farm.id, + queueId=job.queue.id, + jobId=job.id, + priority=priority, + targetTaskRunStatus=target_task_run_status, + maxFailedTasksCount=max_failed_tasks_count, + maxRetriesPerTask=max_retries_per_task, + ) + + def test_wait_until_complete(self, job: Job) -> None: + # GIVEN + common_response_kwargs = { + "jobId": job.id, + "name": job.name, + "lifecycleStatus": job.lifecycle_status, + "lifecycleStatusMessage": job.lifecycle_status_message, + "priority": job.priority, + "createdAt": job.created_at, + "createdBy": job.created_by, + } + mock_client = MagicMock() + mock_client.get_job.side_effect = [ + { + **common_response_kwargs, + "taskRunStatus": "RUNNING", + }, + { + **common_response_kwargs, + "taskRunStatus": "FAILED", + }, + ] + + # WHEN + job.wait_until_complete(client=mock_client, max_retries=1) + + # THEN + assert mock_client.get_job.call_count == 2 + assert job.task_run_status == "FAILED" diff --git a/test/unit/deadline/test_worker.py b/test/unit/deadline/test_worker.py new file mode 100644 index 0000000..2c2878c --- /dev/null +++ b/test/unit/deadline/test_worker.py @@ -0,0 +1,528 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
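+# These tests are intended to run entirely against moto mocks (see the autouse
+# moto_mocks fixture below), with no real AWS credentials or network access.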
+import json +import os +import pathlib +import re +import subprocess +from typing import Any, Generator +from unittest.mock import ANY, MagicMock, call, mock_open, patch + +import boto3 +import pytest +from botocore.exceptions import ClientError +from moto import mock_ec2, mock_iam, mock_s3, mock_ssm + +from deadline_test_scaffolding.deadline import worker as mod +from deadline_test_scaffolding import ( + CodeArtifactRepositoryInfo, + CommandResult, + DeadlineWorkerConfiguration, + DockerContainerWorker, + EC2InstanceWorker, + PipInstall, + S3Object, + ServiceModel, +) + + +@pytest.fixture(autouse=True) +def moto_mocks() -> Generator[None, None, None]: + with mock_ec2(), mock_iam(), mock_s3(), mock_ssm(): + yield + + +@pytest.fixture(autouse=True) +def mock_sleep() -> Generator[None, None, None]: + # We don't want to sleep in unit tests + with patch.object(mod.time, "sleep"): + yield + + +@pytest.fixture(autouse=True) +def wait_for_shim() -> Generator[None, None, None]: + import sys + from deadline_test_scaffolding.util import wait_for + + # Force the wait_for to have a short interval for unit tests + def wait_for_shim(*args, **kwargs): + kwargs.pop("interval_s", None) + kwargs.pop("max_retries", None) + wait_for(*args, **kwargs, interval_s=sys.float_info.epsilon, max_retries=None) + + with patch.object(mod, "wait_for", wait_for_shim): + yield + + +@pytest.fixture +def region(boto_config: dict[str, str]) -> str: + return boto_config["AWS_DEFAULT_REGION"] + + +@pytest.fixture +def deadline_client() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def worker_config(region: str) -> DeadlineWorkerConfiguration: + return DeadlineWorkerConfiguration( + farm_id="farm-123", + fleet_id="fleet-123", + region=region, + user="test-user", + group="test-group", + allow_shutdown=False, + worker_agent_install=PipInstall( + requirement_specifiers=["deadline-cloud-worker-agent"], + codeartifact=CodeArtifactRepositoryInfo( + region=region, + domain="test-domain", + domain_owner="123456789123", + repository="test-repository", + ), + ), + file_mappings=[ + ("/tmp/file1.txt", "/home/test-user/file1.txt"), + ("/packages/manifest.json", "/tmp/manifest.json"), + ("/aws/models/deadline.json", "/tmp/deadline.json"), + ], + service_model=ServiceModel( + "aws configure add-model --service-model file:///tmp/deadline.json" + ), + ) + + +class TestEC2InstanceWorker: + @staticmethod + def describe_instance(instance_id: str) -> Any: + ec2_client = boto3.client("ec2") + response = ec2_client.describe_instances(InstanceIds=[instance_id]) + + reservations = response["Reservations"] + assert len(reservations) == 1 + + instances = reservations[0]["Instances"] + assert len(instances) == 1 + + return instances[0] + + @pytest.fixture + def vpc_id(self) -> str: + return boto3.client("ec2").create_vpc(CidrBlock="10.0.0.0/28")["Vpc"]["VpcId"] + + @pytest.fixture + def subnet_id(self, vpc_id: str) -> str: + return boto3.client("ec2").create_subnet( + VpcId=vpc_id, + CidrBlock="10.0.0.0/28", + )[ + "Subnet" + ]["SubnetId"] + + @pytest.fixture + def security_group_id(self, vpc_id: str) -> str: + return boto3.client("ec2").create_security_group( + VpcId=vpc_id, + Description="Testing", + GroupName="TestSG", + )["GroupId"] + + @pytest.fixture + def instance_profile(self) -> Any: + return boto3.client("iam").create_instance_profile(InstanceProfileName="instance-profile")[ + "InstanceProfile" + ] + + @pytest.fixture + def instance_profile_name(self, instance_profile: Any) -> str: + return instance_profile["InstanceProfileName"] + + 
@pytest.fixture + def bootstrap_bucket_name(self, region: str) -> str: + name = "bootstrap-bucket" + kwargs: dict[str, Any] = {"Bucket": name} + if region != "us-east-1": + kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region} + boto3.client("s3").create_bucket(**kwargs) + return name + + @pytest.fixture + def worker( + self, + deadline_client: MagicMock, + worker_config: DeadlineWorkerConfiguration, + subnet_id: str, + security_group_id: str, + instance_profile_name: str, + bootstrap_bucket_name: str, + ) -> EC2InstanceWorker: + return EC2InstanceWorker( + subnet_id=subnet_id, + security_group_id=security_group_id, + instance_profile_name=instance_profile_name, + bootstrap_bucket_name=bootstrap_bucket_name, + s3_client=boto3.client("s3"), + ec2_client=boto3.client("ec2"), + ssm_client=boto3.client("ssm"), + deadline_client=deadline_client, + configuration=worker_config, + ) + + @patch.object(mod, "open", mock_open(read_data="mock data".encode())) + def test_start(self, worker: EC2InstanceWorker) -> None: + # GIVEN + s3_files = [ + ("s3://bucket/key", "/tmp/key"), + ("s3://bucket/tmp/file", "/tmp/file"), + ] + with ( + patch.object(worker, "_stage_s3_bucket", return_value=s3_files) as mock_stage_s3_bucket, + patch.object(worker, "_launch_instance") as mock_launch_instance, + patch.object(worker, "_start_worker_agent") as mock_start_worker_agent, + ): + # WHEN + worker.start() + + # THEN + # Detailed testing for each of these is done in dedicated test methods + mock_stage_s3_bucket.assert_called_once() + mock_launch_instance.assert_called_once_with(s3_files=s3_files) + mock_start_worker_agent.assert_called_once() + + def test_stage_s3_bucket( + self, + worker: EC2InstanceWorker, + worker_config: DeadlineWorkerConfiguration, + bootstrap_bucket_name: str, + ) -> None: + # GIVEN + # We don't want to actually match real files, just limit src paths to absolute paths + with ( + patch.object(mod.glob, "glob", lambda path: [path]), + patch.object(mod, "open", mock_open(read_data="mock data".encode())), + ): + # WHEN + s3_files = worker._stage_s3_bucket() + + # THEN + # Verify mappings are correct + assert s3_files is not None and worker_config.file_mappings is not None + assert len(s3_files) == len(worker_config.file_mappings) + for src, dst in worker_config.file_mappings: + assert (f"s3://{bootstrap_bucket_name}/worker/{os.path.basename(src)}", dst) in s3_files + + # Verify files are uploaded to S3 + s3_client = boto3.client("s3") + for s3_uri, _ in s3_files: + s3_obj = S3Object.from_uri(s3_uri) + s3_client.head_object(Bucket=s3_obj.bucket, Key=s3_obj.key) + + def test_launch_instance( + self, + worker: EC2InstanceWorker, + vpc_id: str, + subnet_id: str, + security_group_id: str, + instance_profile: Any, + ) -> None: + # WHEN + worker._launch_instance() + + # THEN + assert worker.instance_id is not None + + instance = TestEC2InstanceWorker.describe_instance(worker.instance_id) + assert instance["ImageId"] == worker.ami_id + assert instance["State"]["Name"] == "running" + assert instance["SubnetId"] == subnet_id + assert instance["VpcId"] == vpc_id + assert instance["IamInstanceProfile"]["Arn"] == instance_profile["Arn"] + assert len(instance["SecurityGroups"]) == 1 + assert instance["SecurityGroups"][0]["GroupId"] == security_group_id + + @pytest.mark.skip( + "There's nothing to test in this method currently since it's just sending SSM commands" + ) + def test_start_worker_agent(self) -> None: + pass + + def test_stop(self, worker: EC2InstanceWorker) -> None: + # GIVEN + 
worker.start() + instance_id = worker.instance_id + assert instance_id is not None + + instance = TestEC2InstanceWorker.describe_instance(instance_id) + assert instance["State"]["Name"] == "running" + + # WHEN + worker.stop() + + # THEN + instance = TestEC2InstanceWorker.describe_instance(instance_id) + assert instance["State"]["Name"] == "terminated" + assert worker.instance_id is None + + class TestSendCommand: + def test_sends_command(self, worker: EC2InstanceWorker) -> None: + # GIVEN + cmd = 'echo "Hello world"' + worker.start() + + # WHEN + with patch.object( + worker.ssm_client, "send_command", wraps=worker.ssm_client.send_command + ) as send_command_spy: + worker.send_command(cmd) + + # THEN + send_command_spy.assert_called_once_with( + InstanceIds=[worker.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [cmd]}, + ) + + def test_retries_when_instance_not_ready(self, worker: EC2InstanceWorker) -> None: + # GIVEN + cmd = 'echo "Hello world"' + worker.start() + real_send_command = worker.ssm_client.send_command + + call_count = 0 + + def side_effect(*args, **kwargs): + nonlocal call_count + if call_count < 1: + call_count += 1 + raise ClientError({"Error": {"Code": "InvalidInstanceId"}}, "SendCommand") + else: + return real_send_command(*args, **kwargs) + + # WHEN + with patch.object( + worker.ssm_client, "send_command", side_effect=side_effect + ) as mock_send_command: + worker.send_command(cmd) + + # THEN + mock_send_command.assert_has_calls( + [ + call( + InstanceIds=[worker.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [cmd]}, + ) + ] + * 2 + ) + + def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None: + # GIVEN + cmd = 'echo "Hello world"' + worker.start() + err = ClientError({"Error": {"Code": "SomethingWentWrong"}}, "SendCommand") + + # WHEN + with pytest.raises(ClientError) as raised_err: + with patch.object( + worker.ssm_client, "send_command", side_effect=err + ) as mock_send_command: + worker.send_command(cmd) + + # THEN + assert raised_err.value is err + mock_send_command.assert_called_once_with( + InstanceIds=[worker.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [cmd]}, + ) + + @pytest.mark.parametrize( + "worker_id", + [ + "worker-7c3377ec9eba444bb51cc7da18463081", + "worker-7c3377ec9eba444bb51cc7da18463081\n", + "worker-7c3377ec9eba444bb51cc7da18463081\r\n", + ], + ) + def test_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None: + # GIVEN + with patch.object( + worker, "send_command", return_value=CommandResult(exit_code=0, stdout=worker_id) + ): + # WHEN + result = worker.worker_id + + # THEN + assert result == worker_id.rstrip("\n\r") + + def test_ami_id(self, worker: EC2InstanceWorker) -> None: + # WHEN + ami_id = worker.ami_id + + # THEN + assert re.match(r"^ami-[0-9a-f]{17}$", ami_id) + + +class TestDockerContainerWorker: + @pytest.fixture + def worker(self, worker_config: DeadlineWorkerConfiguration) -> DockerContainerWorker: + return DockerContainerWorker(configuration=worker_config) + + def test_start( + self, + worker: DockerContainerWorker, + worker_config: DeadlineWorkerConfiguration, + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + caplog.set_level("INFO") + + # file_mappings + tmpdir = "/tmp" + + # subprocess.Popen("./run_container.sh") + run_container_stdout_lines = ["line1", "line2", ""] + mock_proc = MagicMock() + mock_proc.stdout.readline.side_effect = run_container_stdout_lines + mock_proc.wait.return_value = 0 + + # 
subprocess.check_output("cat .container_id") + container_id = "798914422427460f83827544bfca1d83" + + with ( + patch.object(mod, "shutil") as mock_shutil, + patch.object(mod.tempfile, "mkdtemp", return_value=tmpdir), + patch.object(mod.os, "makedirs") as mock_makedirs, + patch.object(mod.subprocess, "Popen") as mock_Popen, + patch.object(mod.subprocess, "check_output") as mock_check_output, + ): + mock_Popen.return_value = mock_proc + mock_check_output.return_value = container_id + + # WHEN + worker.start() + + # THEN + mock_shutil.copytree.assert_called_once_with(ANY, tmpdir, dirs_exist_ok=True) + + # Verify file_mappings dir is staged + file_mappings_dir = f"{tmpdir}/file_mappings" + mock_makedirs.assert_called_once_with(file_mappings_dir) + assert worker_config.file_mappings + for src, _ in worker_config.file_mappings: + mock_shutil.copyfile.assert_any_call( + src, f"{file_mappings_dir}/{os.path.basename(src)}" + ) + + # Verify subprocess.Popen("./run_container.sh") + _, popen_kwargs = mock_Popen.call_args + assert popen_kwargs["args"] == "./run_container.sh" + assert popen_kwargs["cwd"] == ANY + assert popen_kwargs["stdout"] == subprocess.PIPE + assert popen_kwargs["stderr"] == subprocess.STDOUT + assert popen_kwargs["text"] is True + assert popen_kwargs["encoding"] == "utf-8" + expected_env = { + "FILE_MAPPINGS": ANY, + "AGENT_USER": worker_config.user, + "SHARED_GROUP": worker_config.group, + "JOB_USER": "jobuser", + "CONFIGURE_WORKER_AGENT_CMD": ANY, + } + actual_env = popen_kwargs["env"] + for expected_key, expected_value in expected_env.items(): + assert expected_key in actual_env + assert actual_env[expected_key] == expected_value + assert all(line in caplog.text for line in run_container_stdout_lines) + mock_proc.wait.assert_called_once() + + # Verify FILE_MAPPINGS env var is generated correctly + actual_file_mappings = json.loads(actual_env["FILE_MAPPINGS"]) + for src, dst in worker_config.file_mappings: + docker_src = f"/file_mappings/{os.path.basename(src)}" + assert docker_src in actual_file_mappings + assert actual_file_mappings[docker_src] == dst + + # Verify subprocess.check_output("cat .container_id") + _, check_output_kwargs = mock_check_output.call_args + assert check_output_kwargs["args"] == ["cat", ".container_id"] + assert check_output_kwargs["cwd"] == ANY + assert check_output_kwargs["text"] is True + assert check_output_kwargs["encoding"] == "utf-8" + assert check_output_kwargs["timeout"] == 1 + assert worker.container_id == container_id + + def test_stop( + self, worker: DockerContainerWorker, worker_config: DeadlineWorkerConfiguration + ) -> None: + # GIVEN + container_id = "container-id" + worker._container_id = container_id + worker._tmpdir = pathlib.Path("/tmp") + + with ( + patch.object(worker, "send_command") as mock_send_command, + patch.object(mod.subprocess, "check_output") as mock_check_output, + ): + # WHEN + worker.stop() + + # THEN + assert worker.container_id is None + mock_send_command.assert_called_once_with(f"pkill --signal term -f {worker_config.user}") + mock_check_output.assert_called_once_with( + args=["docker", "container", "stop", container_id], + cwd=ANY, + text=True, + encoding="utf-8", + timeout=30, + ) + + def test_send_command(self, worker: DockerContainerWorker) -> None: + # GIVEN + worker._container_id = "container-id" + cmd = 'echo "Hello world"' + mock_run_result = MagicMock() + mock_run_result.returncode = 0 + mock_run_result.stdout = "Hello world" + mock_run_result.stderr = None + + with patch.object(mod.subprocess, "run", 
return_value=mock_run_result) as mock_run: + # WHEN + result = worker.send_command(cmd) + + # THEN + assert result.exit_code == 0 + assert result.stdout == "Hello world" + assert result.stderr is None + mock_run.assert_called_once_with( + args=[ + "docker", + "exec", + worker.container_id, + "/bin/bash", + "-euo", + "pipefail", + "-c", + cmd, + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + encoding="utf-8", + ) + + def test_worker_id(self, worker: DockerContainerWorker) -> None: + # GIVEN + worker._container_id = "container-id" + worker_id = "worker-3ff2c8b6c6a5452f8f7b85cd45b80ba3" + send_command_result = CommandResult(0, f"{worker_id}\r\n") + + with patch.object(worker, "send_command", return_value=send_command_result): + # WHEN + result = worker.worker_id + + # THEN + assert result == worker_id diff --git a/test/unit/shared_constants.py b/test/unit/shared_constants.py deleted file mode 100644 index 94f498d..0000000 --- a/test/unit/shared_constants.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -MOCK_FARM_ID = "farm-0123456789abcdefabcdefabcdefabcd" -MOCK_FARM_NAME = "fake_farm_name" -MOCK_FLEET_ID = "fleet-0123456789abcdefabcdefabcdefabcd" -MOCK_FLEET_NAME = "fake_fleet_name" -MOCK_QUEUE_ID = "queue-0123456789abcdefabcdefabcdefabcd" -MOCK_QUEUE_NAME = "fake_queue_name" -MOCK_WORKER_ROLE_ARN = "fake_worker_role_arn" -MOCK_JOB_ATTACHMENTS_BUCKET_NAME = "fake_job_attachments_bucket_name" - -MOCK_DEFAULT_CMF_CONFIG = { - "customerManaged": { - "autoScalingConfiguration": { - "mode": "NO_SCALING", - "maxFleetSize": 1, - }, - "workerRequirements": { - "vCpuCount": {"min": 1}, - "memoryMiB": {"min": 1024}, - "osFamily": "linux", - "cpuArchitectureType": "x86_64", - }, - } -} diff --git a/test/unit/test_deadline_manager.py b/test/unit/test_deadline_manager.py deleted file mode 100644 index 3c860f4..0000000 --- a/test/unit/test_deadline_manager.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -from __future__ import annotations - -import os -from typing import Any -from unittest import mock - -import pytest -from botocore.exceptions import ClientError - -from deadline_test_scaffolding import DeadlineManager - -from shared_constants import ( - MOCK_FARM_ID, - MOCK_FARM_NAME, - MOCK_FLEET_ID, - MOCK_FLEET_NAME, - MOCK_QUEUE_ID, - MOCK_QUEUE_NAME, - MOCK_DEFAULT_CMF_CONFIG, -) - - -class TestDeadlineManager: - @pytest.fixture(autouse=True) - def setup_test(self, mock_get_deadline_models): - pass - - @pytest.fixture(scope="function") - def mock_deadline_manager(self) -> DeadlineManager: - """ - Returns a DeadlineManager where any boto3 clients are mocked, including - the deadline_client that is part of the DeadlineManager. 
- """ - with mock.patch.object(DeadlineManager, "_get_deadline_client"), mock.patch( - "deadline_test_scaffolding.deadline_manager.boto3.client" - ): - return DeadlineManager() - - ids = [ - pytest.param(None, None, None, None, id="NoKMSKey"), - pytest.param({"KeyId": "FakeKMSKeyID"}, None, None, None, id="KMSKeyNoFarm"), - pytest.param({"KeyId": "FakeKMSKeyID"}, MOCK_FARM_ID, None, None, id="KMSKeyFarmNoFleet"), - pytest.param( - {"KeyId": "FakeKMSKeyID"}, - MOCK_FARM_ID, - MOCK_FLEET_ID, - None, - id="KMSKeyFarmFleetNoQueue", - ), - pytest.param( - {"KeyId": "FakeKMSKeyID"}, - MOCK_FARM_ID, - MOCK_FLEET_ID, - MOCK_QUEUE_ID, - id="KMSKeyFarmFleetQueue", - ), - pytest.param( - {"KeyId": "FakeKMSKeyID"}, - MOCK_FARM_ID, - None, - MOCK_QUEUE_ID, - id="KMSKeyFarmQueueNoFleet", - ), - ] - - @mock.patch.object(DeadlineManager, "create_fleet") - @mock.patch.object(DeadlineManager, "create_queue") - @mock.patch.object(DeadlineManager, "create_farm") - @mock.patch.object(DeadlineManager, "create_kms_key") - @mock.patch.object(DeadlineManager, "add_job_attachments_bucket") - @mock.patch.object(DeadlineManager, "queue_fleet_association") - def test_create_scaffolding( - self, - mocked_create_kms_key: mock.Mock, - mocked_create_farm: mock.Mock, - mocked_create_queue: mock.Mock, - mocked_create_fleet: mock.Mock, - mocked_queue_fleet_association: mock.Mock, - mocked_add_job_attachments_bucket: mock.Mock, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.farm_id = MOCK_FARM_ID - mock_deadline_manager.fleet_id = MOCK_FLEET_ID - mock_deadline_manager.queue_id = MOCK_QUEUE_ID - worker_role_arn = "fake_worker_role" - job_attachments_bucket = "fake_job_attachments_bucket" - - # WHEN - mock_deadline_manager.create_scaffolding(worker_role_arn, job_attachments_bucket) - - mocked_create_kms_key.assert_called_once() - mocked_create_farm.assert_called_once() - mocked_create_queue.assert_called_once() - mocked_add_job_attachments_bucket.assert_called_once() - mocked_create_fleet.assert_called_once() - mocked_queue_fleet_association.assert_called_once() - - @mock.patch.object(DeadlineManager, "delete_fleet") - @mock.patch.object(DeadlineManager, "delete_queue") - @mock.patch.object(DeadlineManager, "delete_farm") - @mock.patch.object(DeadlineManager, "delete_kms_key") - @pytest.mark.parametrize("kms_key_metadata, farm_id, fleet_id, queue_id", ids) - def test_cleanup_scaffolding( - self, - mocked_delete_kms_key: mock.Mock, - mocked_delete_farm: mock.Mock, - mocked_delete_queue: mock.Mock, - mocked_delete_fleet: mock.Mock, - kms_key_metadata: dict[str, Any] | None, - farm_id: str | None, - fleet_id: str | None, - queue_id: str | None, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.kms_key_metadata = kms_key_metadata - mock_deadline_manager.farm_id = farm_id - mock_deadline_manager.fleet_id = fleet_id - mock_deadline_manager.queue_id = queue_id - - # WHEN - mock_deadline_manager.cleanup_scaffolding() - - # c - if fleet_id: - mocked_delete_fleet.assert_called_once() - - if queue_id: - mocked_delete_queue.assert_called_once() - - if farm_id: - mocked_delete_farm.assert_called_once() - - if kms_key_metadata: - mocked_delete_kms_key.assert_called_once() - - def test_create_kms_key(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - fake_kms_metadata = {"KeyMetadata": {"KeyId": "Foo"}} - mock_deadline_manager.kms_client.create_key.return_value = fake_kms_metadata - - # WHEN - mock_deadline_manager.create_kms_key() - - # THEN - 
mock_deadline_manager.kms_client.create_key.assert_called_once_with( - Description="The KMS used for testing created by the " - "DeadlineClientSoftwareTestScaffolding.", - Tags=[{"TagKey": "Name", "TagValue": "DeadlineClientSoftwareTestScaffolding"}], - ) - - assert mock_deadline_manager.kms_key_metadata == fake_kms_metadata["KeyMetadata"] - - mock_deadline_manager.kms_client.enable_key.assert_called_once_with( - KeyId=fake_kms_metadata["KeyMetadata"]["KeyId"] - ) - - def test_delete_kms_key(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - fake_kms_metadata = {"KeyId": "Foo"} - mock_deadline_manager.kms_key_metadata = fake_kms_metadata - - # WHEN - mock_deadline_manager.delete_kms_key() - - # THEN - mock_deadline_manager.kms_client.schedule_key_deletion.assert_called_once_with( - KeyId=fake_kms_metadata["KeyId"], PendingWindowInDays=7 - ) - - assert mock_deadline_manager.kms_key_metadata is None - - key_metadatas = [ - pytest.param(None, id="NoMetadata"), - pytest.param({"Foo": "Bar"}, id="NoKeyInMetadata"), - ] - - @pytest.mark.parametrize("key_metadatas", key_metadatas) - def test_delete_kms_key_no_key( - self, - key_metadatas: dict[str, Any] | None, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.kms_key_metadata = key_metadatas - - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.delete_kms_key() - - assert not mock_deadline_manager.kms_client.schedule_key_deletion.called - - def test_create_farm(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - fake_kms_metadata = {"Arn": "fake_kms_arn"} - - mock_deadline_manager.kms_key_metadata = fake_kms_metadata - mock_deadline_manager.deadline_client.create_farm.return_value = {"farmId": MOCK_FARM_ID} # type: ignore[attr-defined] - - # WHEN - mock_deadline_manager.create_farm(MOCK_FARM_NAME) - - # THEN - mock_deadline_manager.deadline_client.create_farm.assert_called_once_with( # type: ignore[attr-defined] # noqa - displayName=MOCK_FARM_NAME, kmsKeyArn=fake_kms_metadata["Arn"] - ) - assert mock_deadline_manager.farm_id == MOCK_FARM_ID - - key_metadatas = [ - pytest.param(None, id="NoMetadata"), - pytest.param({"Foo": "Bar"}, id="NoKeyInMetadata"), - ] - - @pytest.mark.parametrize("key_metadatas", key_metadatas) - def test_create_farm_kms_not_valid( - self, - key_metadatas: dict[str, Any] | None, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.kms_key_metadata = key_metadatas - - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.create_farm(MOCK_FARM_NAME) - - assert not mock_deadline_manager.deadline_client.create_farm.called # type: ignore[attr-defined] # noqa - assert mock_deadline_manager.farm_id is None - - def test_delete_farm(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.farm_id = MOCK_FARM_ID - - # WHEN - mock_deadline_manager.delete_farm() - - # THEN - mock_deadline_manager.deadline_client.delete_farm.assert_called_once_with( - farmId=MOCK_FARM_ID - ) - - assert mock_deadline_manager.farm_id is None - - def test_delete_farm_not_created(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - # mock_deadline_manager fixture - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.delete_farm() - - # THEN - assert not mock_deadline_manager.deadline_client.delete_farm.called - - def test_create_queue(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.farm_id = MOCK_FARM_ID - 
mock_deadline_manager.deadline_client.create_queue.return_value = {"queueId": MOCK_QUEUE_ID} # type: ignore[attr-defined] - - # WHEN - mock_deadline_manager.create_queue(MOCK_QUEUE_NAME) - - # THEN - mock_deadline_manager.deadline_client.create_queue.assert_called_once_with( # type: ignore[attr-defined] - displayName=MOCK_QUEUE_NAME, - farmId=MOCK_FARM_ID, - ) - - assert mock_deadline_manager.queue_id == MOCK_QUEUE_ID - - def test_create_queue_no_farm(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.deadline_client.create_queue.return_value = {"queueId": MOCK_QUEUE_ID} # type: ignore[attr-defined] - - # WHEN - with pytest.raises(Exception): - mock_deadline_manager.create_queue(MOCK_QUEUE_NAME) - - # THEN - assert not mock_deadline_manager.deadline_client.create_queue.called # type: ignore[attr-defined] - - assert mock_deadline_manager.queue_id is None - - def test_delete_queue(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.queue_id = MOCK_QUEUE_ID - mock_deadline_manager.farm_id = MOCK_FARM_ID - - # WHEN - mock_deadline_manager.delete_queue() - - # THEN - mock_deadline_manager.deadline_client.delete_queue.assert_called_once_with( - queueId=MOCK_QUEUE_ID, farmId=MOCK_FARM_ID - ) - - assert mock_deadline_manager.queue_id is None - - farm_queue_ids = [ - pytest.param(MOCK_QUEUE_ID, None, id="NoFarmId"), - pytest.param(None, MOCK_FARM_ID, id="NoQueueId"), - ] - - @pytest.mark.parametrize("fake_queue_id, fake_farm_id", farm_queue_ids) - def test_delete_queue_no_farm_queue( - self, - fake_queue_id: str | None, - fake_farm_id: str | None, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.queue_id = fake_queue_id - mock_deadline_manager.farm_id = fake_farm_id - - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.delete_queue() - - assert not mock_deadline_manager.deadline_client.delete_queue.called - - def test_create_fleet(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.farm_id = MOCK_FARM_ID - fake_worker_role_arn = "fake_worker_role_arn" - mock_deadline_manager.deadline_client.create_fleet.return_value = {"fleetId": MOCK_FLEET_ID} # type: ignore[attr-defined] - mock_deadline_manager.deadline_client.get_fleet.return_value = {"status": "ACTIVE"} # type: ignore[attr-defined] - - # WHEN - mock_deadline_manager.create_fleet(MOCK_FLEET_NAME, fake_worker_role_arn) - - # THEN - mock_deadline_manager.deadline_client.create_fleet.assert_called_once_with( # type: ignore[attr-defined] - farmId=MOCK_FARM_ID, - displayName=MOCK_FLEET_NAME, - roleArn=fake_worker_role_arn, - configuration=MOCK_DEFAULT_CMF_CONFIG, - ) - - assert mock_deadline_manager.fleet_id == MOCK_FLEET_ID - - def test_create_fleet_no_farm(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - # mock_deadline_manager fixture - worker_role_arn = "fake_worker_role_arn" - - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.create_fleet(MOCK_FLEET_NAME, worker_role_arn) - - assert not mock_deadline_manager.deadline_client.create_fleet.called # type: ignore[attr-defined] - assert mock_deadline_manager.fleet_id is None - - def test_delete_fleet(self, mock_deadline_manager: DeadlineManager) -> None: - # GIVEN - mock_deadline_manager.farm_id = MOCK_FARM_ID - mock_deadline_manager.fleet_id = MOCK_FLEET_ID - mock_deadline_manager.deadline_client.get_queue_fleet_association.return_value = {"status": "STOPPED"} # type: ignore[attr-defined] - 
mock_deadline_manager.deadline_client.get_fleet.return_value = {"status": "DELETED"} # type: ignore[attr-defined] - - # WHEN - mock_deadline_manager.delete_fleet() - - # THEN - mock_deadline_manager.deadline_client.delete_fleet.assert_called_once_with( - farmId=MOCK_FARM_ID, fleetId=MOCK_FLEET_ID - ) - - assert mock_deadline_manager.fleet_id is None - - farm_queue_ids = [ - pytest.param(MOCK_FARM_ID, None, id="NoFleetId"), - pytest.param(None, MOCK_FLEET_ID, id="NoFarmId"), - ] - - # Create a test for test_delete_fleet - - @pytest.mark.parametrize("fake_farm_id, fake_fleet_id", farm_queue_ids) - def test_delete_fleet_no_farm_fleet( - self, - fake_farm_id: str | None, - fake_fleet_id: str | None, - mock_deadline_manager: DeadlineManager, - ) -> None: - # GIVEN - mock_deadline_manager.farm_id = fake_farm_id - mock_deadline_manager.fleet_id = fake_fleet_id - - # WHEN / THEN - with pytest.raises(Exception): - mock_deadline_manager.delete_fleet() - - farm_queue_ids = [ - pytest.param( - "kms_client", - {}, - "create_key", - "create_kms_key", - [], - "kms_key_metadata", - id="FailedCreateKMSKey", - ), - pytest.param( - "kms_client", - {"kms_key_metadata": {"KeyId": "TestKeyId"}}, - "schedule_key_deletion", - "delete_kms_key", - [], - None, - id="FailedDeleteKMSKey", - ), - pytest.param( - "deadline_client", - {"kms_key_metadata": {"Arn": "TestArn"}}, - "create_farm", - "create_farm", - ["TestFarm"], - "farm_id", - id="FailedCreateFarm", - ), - pytest.param( - "deadline_client", - {"farm_id": "fake_farm_id"}, - "delete_farm", - "delete_farm", - [], - None, - id="FailedDeleteFarm", - ), - pytest.param( - "deadline_client", - {"farm_id": "fake_farm_id"}, - "create_queue", - "create_queue", - ["TestQueue"], - "queue_id", - id="FailedCreateQueue", - ), - pytest.param( - "deadline_client", - {"farm_id": "fake_farm_id", "queue_id": "fake_queue_id"}, - "delete_queue", - "delete_queue", - [], - None, - id="FailedDeleteQueue", - ), - pytest.param( - "deadline_client", - {"farm_id": "fake_farm_id", "worker_role_arn": "fake_worker_role_arn"}, - "create_fleet", - "create_fleet", - ["TestFleet", "fake_worker_arn"], - "fleet_id", - id="FailedCreateFleet", - ), - pytest.param( - "deadline_client", - {"farm_id": "fake_farm_id", "fleet_id": "fake_fleet_id"}, - "get_queue_fleet_association", # This is the first boto call in delete fleet - "delete_fleet", - [], - None, - id="FailedDeleteFleet", - ), - ] - - @mock.patch("deadline_test_scaffolding.deadline_manager.boto3.Session") - @mock.patch("deadline_test_scaffolding.deadline_manager.boto3.client") - @pytest.mark.parametrize( - "client, bm_properties, client_function_name, manager_function_name, args," - "expected_parameter", - farm_queue_ids, - ) - def test_failure_with_boto( - self, - _: mock.Mock, - mocked_boto_session: mock.MagicMock, - client: str, - bm_properties: dict[str, Any], - client_function_name: str, - manager_function_name: str, - args: list[Any], - expected_parameter: str, - ) -> None: - """This test will confirm that when a ClientError is raised when we use the boto3 - clients for deadline and kms - - Args: - _ (mock.Mock): _description_ - client (str): _description_ - bm_properties (dict[str, Any]): _description_ - client_function_name (str): _description_ - manager_function_name (str): _description_ - args (list[Any]): _description_ - expected_parameter (str): _description_ - """ - - # GIVEN - mocked_function = mock.Mock( - side_effect=ClientError( - { - "Error": { - "Code": "TestException", - "Message": "This is a test exception to simulate an 
exception being " - "raised.", - } - }, - "TestException", - ) - ) - mocked_client = mock.Mock() - setattr(mocked_client, client_function_name, mocked_function) - - bm = DeadlineManager() - setattr(bm, client, mocked_client) - - for property, value in bm_properties.items(): - setattr(bm, property, value) - - # WHEN - with pytest.raises(ClientError): - manager_function = getattr(bm, manager_function_name) - manager_function(*args) - - # THEN - if expected_parameter: - assert getattr(bm, expected_parameter) is None - - -class TestDeadlineManagerAddModels: - """This class is here because the tests above are mocking out the add_deadline_models method - using a fixture.""" - - @mock.patch.dict(os.environ, {"DEADLINE_SERVICE_MODEL_BUCKET": "test-bucket"}) - @mock.patch("os.makedirs") - @mock.patch("tempfile.TemporaryDirectory") - @mock.patch("deadline_test_scaffolding.deadline_manager.boto3.Session") - @mock.patch("deadline_test_scaffolding.deadline_manager.boto3.client") - def test_get_deadline_models( - self, - mocked_boto_client: mock.MagicMock, - mocked_boto_session: mock.MagicMock, - mocked_temp_dir: mock.MagicMock, - mocked_mkdir: mock.MagicMock, - ): - # GIVEN - temp_path = "/tmp/test" - mocked_temp_dir.return_value.name = temp_path - deadline_endpoint = os.getenv("DEADLINE_ENDPOINT") - - # WHEN - manager = DeadlineManager(should_add_deadline_models=True) - - # THEN - mocked_boto_client.assert_any_call("s3") - mocked_temp_dir.assert_called_once() - mocked_mkdir.assert_called_once_with( - f"{temp_path}/deadline/{DeadlineManager.MOCKED_SERVICE_VERSION}" - ) - mocked_boto_client.return_value.download_file.assert_called_with( - "test-bucket", - "service-2.json", - f"{temp_path}/deadline/{DeadlineManager.MOCKED_SERVICE_VERSION}/service-2.json", - ) - mocked_boto_session.return_value.client.assert_called_with( - "deadline", endpoint_url=deadline_endpoint - ) - assert manager.deadline_model_dir is not None - assert manager.deadline_model_dir.name == temp_path diff --git a/test/unit/test_job_attachment_manager.py b/test/unit/test_job_attachment_manager.py index 2289a81..8708780 100644 --- a/test/unit/test_job_attachment_manager.py +++ b/test/unit/test_job_attachment_manager.py @@ -1,12 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -from datetime import datetime -from unittest import mock +from typing import Generator +from unittest.mock import MagicMock, patch import pytest from botocore.exceptions import ClientError, WaiterError -from botocore.stub import ANY, Stubber +from moto import mock_s3 +from deadline_test_scaffolding import job_attachment_manager as jam_module from deadline_test_scaffolding import JobAttachmentManager @@ -16,359 +17,168 @@ class TestJobAttachmentManager: """ @pytest.fixture(autouse=True) - def setup_test(self, mock_get_deadline_models, boto_config): - with mock.patch("deadline_test_scaffolding.job_attachment_manager.DeadlineManager"): - self.job_attachment_manager = JobAttachmentManager( - stage="test", account_id="123456789101" - ) - yield - - def test_deploy_resources(self): - """ - Test that during the normal flow that the upgrade stack boto call is made. 
- """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber: - stubber.add_response( - "update_stack", - { - "StackId": "arn:aws:cloudformation:us-west-2:123456789101:stack/" - "JobAttachmentIntegTest/abcdefgh-1234-ijkl-5678-mnopqrstuvwx" - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - "TemplateBody": ANY, - "Parameters": [ - { - "ParameterKey": "BucketName", - "ParameterValue": "job-attachment-integ-test-test-123456789101", - }, - ], - }, - ) - stubber.add_response( - "describe_stacks", - { - "Stacks": [ - { - "StackName": "JobAttachmentIntegTest", - "CreationTime": datetime(2015, 1, 1), - "StackStatus": "UPDATE_COMPLETE", - }, - ], - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - }, - ) - self.job_attachment_manager.deploy_resources() - - stubber.assert_no_pending_responses() - - @mock.patch( - "deadline_test_scaffolding.job_attachment_manager." "JobAttachmentManager.cleanup_resources" - ) - def test_deploy_resources_client_error( - self, - mocked_cleanup_resources: mock.MagicMock, - ): - """ - Test that if there's an issue deploying resources, the rest get cleaned up. - """ - # WHEN - with mock.patch.object( - self.job_attachment_manager.deadline_manager, - "create_kms_key", - side_effect=ClientError( - {"ErrorCode": "Oops", "Message": "Something went wrong"}, "create_kms_key" - ), - ), pytest.raises(ClientError): - self.job_attachment_manager.deploy_resources() + def mock_farm_cls(self) -> Generator[MagicMock, None, None]: + with patch.object(jam_module, "Farm") as mock: + yield mock - mocked_cleanup_resources.assert_called_once() + @pytest.fixture(autouse=True) + def mock_queue_cls(self) -> Generator[MagicMock, None, None]: + with patch.object(jam_module, "Queue") as mock: + yield mock - @mock.patch( - "deadline_test_scaffolding.job_attachment_manager." "JobAttachmentManager.cleanup_resources" - ) - def test_deploy_resources_waiter_error( - self, - mocked_cleanup_resources: mock.MagicMock, - ): - """ - Test that if there's an issue deploying resources, the rest get cleaned up. - But this time with a waiter error. 
- """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber, pytest.raises( - WaiterError + @pytest.fixture(autouse=True) + def mock_stack(self) -> Generator[MagicMock, None, None]: + with patch.object(jam_module, "JobAttachmentsBootstrapStack") as mock: + yield mock.return_value + + @pytest.fixture + def job_attachment_manager(self) -> Generator[JobAttachmentManager, None, None]: + with mock_s3(): + yield JobAttachmentManager( + stage="test", + account_id="123456789101", + ) + + class TestDeployResources: + def test_deploys_all_resources( + self, + job_attachment_manager: JobAttachmentManager, + mock_farm_cls: MagicMock, + mock_queue_cls: MagicMock, + mock_stack: MagicMock, ): - stubber.add_response( - "update_stack", - { - "StackId": "arn:aws:cloudformation:us-west-2:123456789101:stack/" - "JobAttachmentIntegTest/abcdefgh-1234-ijkl-5678-mnopqrstuvwx" - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - "TemplateBody": ANY, - "Parameters": [ - { - "ParameterKey": "BucketName", - "ParameterValue": "job-attachment-integ-test-test-123456789101", - }, - ], - }, - ) - stubber.add_client_error( - "describe_stacks", service_error_code="400", service_message="Oops" - ) - - self.job_attachment_manager.deploy_resources() - - stubber.assert_no_pending_responses() - - mocked_cleanup_resources.assert_called_once() - - def test_deploy_stack_update_while_create_in_progress(self): - """ - Test that if an attempt to update a stack when a stack is in the process of being created, - we wait for the stack to complete being created. - """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber: - stubber.add_client_error( - "update_stack", - service_error_code="400", - service_message="JobAttachmentIntegTest is in CREATE_IN_PROGRESS " - "state and can not be updated.", - ) - stubber.add_response( - "describe_stacks", - { - "Stacks": [ - { - "StackName": "JobAttachmentIntegTest", - "CreationTime": datetime(2015, 1, 1), - "StackStatus": "CREATE_COMPLETE", - }, - ], - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - }, - ) - - self.job_attachment_manager.deploy_stack() - - stubber.assert_no_pending_responses() - - @mock.patch( - "deadline_test_scaffolding.job_attachment_manager." "JobAttachmentManager._create_stack" - ) - def test_deploy_stack_update_while_stack_doesnt_need_updating( - self, mocked__create_stack: mock.MagicMock - ): - """ - Test that if a stack already exists that doesn't need updating, nothing happens. - """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber: - stubber.add_client_error( - "update_stack", - service_error_code="400", - service_message="No updates are to be performed.", - ) - - self.job_attachment_manager.deploy_stack() - - stubber.assert_no_pending_responses() - - mocked__create_stack.assert_not_called() - - def test_deploy_stack_stack_doesnt_exist(self): - """ - Test that if when updating the stack, that it gets created if it doesn't exist. 
- """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber: - stubber.add_client_error( - "update_stack", - service_error_code="400", - service_message="The Stack JobAttachmentIntegTest doesn't exist", - ) - stubber.add_response( - "create_stack", - { - "StackId": "arn:aws:cloudformation:us-west-2:123456789101:stack/" - "JobAttachmentIntegTest/abcdefgh-1234-ijkl-5678-mnopqrstuvwx" - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - "TemplateBody": ANY, - "OnFailure": "DELETE", - "EnableTerminationProtection": False, - "Parameters": [ - { - "ParameterKey": "BucketName", - "ParameterValue": "job-attachment-integ-test-test-123456789101", - }, - ], - }, - ) - stubber.add_response( - "describe_stacks", - { - "Stacks": [ - { - "StackName": "JobAttachmentIntegTest", - "CreationTime": datetime(2015, 1, 1), - "StackStatus": "CREATE_COMPLETE", - }, - ], - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - }, - ) - - self.job_attachment_manager.deploy_stack() - - stubber.assert_no_pending_responses() - - def test_deploy_stack_stack_already_exists(self): - """ - Test the if we try to create a stack when it already exists, - we wait for it to finish being created. - """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber: - stubber.add_client_error( - "update_stack", - service_error_code="400", - service_message="The Stack JobAttachmentIntegTest doesn't exist", - ) - stubber.add_client_error( - "create_stack", - service_error_code="400", - service_message="Stack [JobAttachmentIntegTest] already exists", - ) - stubber.add_response( - "describe_stacks", - { - "Stacks": [ - { - "StackName": "JobAttachmentIntegTest", - "CreationTime": datetime(2015, 1, 1), - "StackStatus": "CREATE_COMPLETE", - }, - ], - }, - expected_params={ - "StackName": "JobAttachmentIntegTest", - }, - ) - - self.job_attachment_manager.deploy_stack() - - stubber.assert_no_pending_responses() - - def test_deploy_stack_other_client_error(self): - """ - Test that when we create a stack, unhandled client errors get raised. - """ - # WHEN - with Stubber(self.job_attachment_manager.stack.meta.client) as stubber, pytest.raises( - ClientError + """ + Tests that all resources are created when deploy_resources is called + """ + # WHEN + job_attachment_manager.deploy_resources() + + # THEN + mock_farm_cls.create.assert_called_once() + mock_queue_cls.create.assert_called_once() + mock_stack.deploy.assert_called_once() + + @pytest.mark.parametrize( + "error", + [ + ClientError({}, None), + WaiterError(None, None, None), + ], + ) + def test_cleans_up_when_error_is_raised( + self, + error: Exception, + job_attachment_manager: JobAttachmentManager, + mock_farm_cls: MagicMock, + mock_queue_cls: MagicMock, + mock_stack: MagicMock, ): - stubber.add_client_error( - "update_stack", - service_error_code="400", - service_message="The Stack JobAttachmentIntegTest doesn't exist", - ) - stubber.add_client_error( - "create_stack", - service_error_code="400", - service_message="Oops", - ) - - self.job_attachment_manager.deploy_stack() - - stubber.assert_no_pending_responses() - - def test_empty_bucket_bucket_doesnt_exist(self): - """ - If we try to empty a bucket that doesn't exist, make sure nothing happens. 
-        """
-        # WHEN
-        with Stubber(self.job_attachment_manager.bucket.meta.client) as stubber:
-            stubber.add_client_error(
-                "list_objects",
-                service_error_code="400",
-                service_message="The specified bucket does not exist",
-            )
-
-            self.job_attachment_manager.empty_bucket()
-
-            stubber.assert_no_pending_responses()
-
-    def test_empty_bucket_any_other_error(self):
-        """
-        Test that unhandled client errors during bucket creation are raised.
-        """
-        # WHEN
-        with Stubber(self.job_attachment_manager.bucket.meta.client) as stubber, pytest.raises(
-            ClientError
+            """
+            If we try to empty a bucket that doesn't exist, make sure nothing happens.
+            """
+            # GIVEN
+            # The bucket does not exist (we do not create it)
+
+            try:
+                # WHEN
+                job_attachment_manager.empty_bucket()
+            except ClientError as e:
+                pytest.fail(
+                    f"JobAttachmentManager.empty_bucket raised an error when it shouldn't have: {e}"
+                )
+            else:
+                # THEN
+                # Success
+                pass
+
+        def test_raises_any_other_error(
+            self,
+            job_attachment_manager: JobAttachmentManager,
         ):
-            stubber.add_client_error(
-                "list_objects",
-                service_error_code="400",
-                service_message="Ooops",
-            )
-
-            self.job_attachment_manager.empty_bucket()
-
-            stubber.assert_no_pending_responses()
-
-    def test_cleanup_resources(self):
+            """
+            Test that unhandled client errors raised while emptying the bucket are propagated.
+            """
+            # GIVEN
+            exc = ClientError({"Error": {"Message": "test"}}, "test-operation")
+            with (
+                patch.object(job_attachment_manager, "bucket") as mock_bucket,
+                pytest.raises(ClientError) as raised_exc,
+            ):
+                mock_bucket.objects.all.side_effect = exc
+
+                # WHEN
+                job_attachment_manager.empty_bucket()
+
+                # THEN
+                assert raised_exc.value is exc
+                mock_bucket.objects.all.assert_called_once()
+
+    def test_cleanup_resources(
+        self,
+        job_attachment_manager: JobAttachmentManager,
+        mock_farm_cls: MagicMock,
+        mock_queue_cls: MagicMock,
+        mock_stack: MagicMock,
+    ):
         """
         Test that all resources get cleaned up when they exist.
""" - self.job_attachment_manager.deadline_manager.farm_id = "farm-asdf" - self.job_attachment_manager.deadline_manager.kms_key_metadata = {"key_id": "aasdfkj"} - self.job_attachment_manager.deadline_manager.queue_id = "queue-asdfji" - - # WHEN - with Stubber(self.job_attachment_manager.bucket.meta.client) as stubber: - stubber.add_response( - "list_objects", - { - "Contents": [], - }, - ) - self.job_attachment_manager.cleanup_resources() - - stubber.assert_no_pending_responses() - - self.job_attachment_manager.deadline_manager.delete_farm.assert_called_once() # type: ignore[attr-defined] # noqa - self.job_attachment_manager.deadline_manager.delete_kms_key.assert_called_once() # type: ignore[attr-defined] # noqa - self.job_attachment_manager.deadline_manager.delete_queue.assert_called_once() # type: ignore[attr-defined] # noqa - - def test_cleanup_resources_no_resource_exist(self): - """ - Test that no deletion calls are made when resources don't exist. - """ - # WHEN - with Stubber(self.job_attachment_manager.bucket.meta.client) as stubber: - stubber.add_response( - "list_objects", - { - "Contents": [], - }, - ) - self.job_attachment_manager.cleanup_resources() + # GIVEN + job_attachment_manager.deploy_resources() - stubber.assert_no_pending_responses() + with patch.object( + job_attachment_manager, "empty_bucket", wraps=job_attachment_manager.empty_bucket + ) as spy_empty_bucket: + # WHEN + job_attachment_manager.cleanup_resources() - self.job_attachment_manager.deadline_manager.create_farm.assert_not_called() # type: ignore[attr-defined] # noqa - self.job_attachment_manager.deadline_manager.create_kms_key.assert_not_called() # type: ignore[attr-defined] # noqa - self.job_attachment_manager.deadline_manager.create_queue.assert_not_called() # type: ignore[attr-defined] # noqa + # THEN + spy_empty_bucket.assert_called_once() + mock_stack.destroy.assert_called_once() + mock_queue_cls.create.return_value.delete.assert_called_once() + mock_farm_cls.create.return_value.delete.assert_called_once()