diff --git a/src/deadline_test_fixtures/deadline/worker.py b/src/deadline_test_fixtures/deadline/worker.py index fe7b4c3..07eaac9 100644 --- a/src/deadline_test_fixtures/deadline/worker.py +++ b/src/deadline_test_fixtures/deadline/worker.py @@ -18,7 +18,6 @@ from dataclasses import dataclass, field, InitVar, replace from typing import Any, ClassVar, Optional, cast -from .client import DeadlineClient from ..models import ( PipInstall, PosixSessionUser, @@ -56,7 +55,10 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str: # if config.service_model: cmds.append( - f"runuser -l {config.user} -s /bin/bash -c '{config.service_model.install_command}'" + f"runuser -l {config.user} -s /bin/bash -c '{' && '.join(config.service_model.install_commands)}'" + ) + cmds.append( + f'echo "AWS_ENDPOINT_URL_DEADLINE=\\"{config.service_model.deadline_endpoint_url_fmt_str.format(config.region)}\\"" >> /etc/environment' ) return " && ".join(cmds) @@ -146,7 +148,6 @@ class EC2InstanceWorker(DeadlineWorker): s3_client: botocore.client.BaseClient ec2_client: botocore.client.BaseClient ssm_client: botocore.client.BaseClient - deadline_client: DeadlineClient configuration: DeadlineWorkerConfiguration instance_id: Optional[str] = field(init=False, default=None) diff --git a/src/deadline_test_fixtures/example_config.sh b/src/deadline_test_fixtures/example_config.sh index e91edfc..c43463c 100644 --- a/src/deadline_test_fixtures/example_config.sh +++ b/src/deadline_test_fixtures/example_config.sh @@ -30,6 +30,10 @@ export CODEARTIFACT_REGION # --- OPTIONAL --- # +# The AWS region to use +# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2 +export REGION + # Extra local path for boto to look for AWS models in # Does not apply to the worker export AWS_DATA_PATH @@ -38,9 +42,10 @@ export AWS_DATA_PATH # Default is to pip install the latest "deadline-cloud-worker-agent" package export WORKER_AGENT_WHL_PATH +# DEPRECATED: Use REGION instead # The AWS region to configure the worker for # Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2 -export WORKER_REGION +# export WORKER_REGION # The POSIX user to configure the worker for # Defaults to "deadline-worker" diff --git a/src/deadline_test_fixtures/fixtures.py b/src/deadline_test_fixtures/fixtures.py index 776e810..ea37dcf 100644 --- a/src/deadline_test_fixtures/fixtures.py +++ b/src/deadline_test_fixtures/fixtures.py @@ -12,7 +12,7 @@ import pytest import tempfile from contextlib import ExitStack, contextmanager -from dataclasses import InitVar, dataclass, field, fields, MISSING +from dataclasses import InitVar, dataclass, field, fields, replace, MISSING from typing import Any, Generator, TypeVar from .deadline.client import DeadlineClient @@ -141,6 +141,11 @@ def codeartifact() -> CodeArtifactRepositoryInfo: ) +@pytest.fixture(scope="session") +def region() -> str: + return os.getenv("REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")) + + @pytest.fixture(scope="session") def service_model() -> Generator[ServiceModel, None, None]: service_model_s3_uri = os.getenv("DEADLINE_SERVICE_MODEL_S3_URI") @@ -168,15 +173,22 @@ def service_model() -> Generator[ServiceModel, None, None]: if not local_model_path: local_model_path = _find_latest_service_model_file("deadline") LOG.info(f"Using service model at: {local_model_path}") - yield ServiceModel.from_json_file(local_model_path) + if local_model_path.endswith(".json"): + yield ServiceModel.from_json_file(local_model_path) + elif local_model_path.endswith(".json.gz"): + yield ServiceModel.from_json_gz_file(local_model_path) + else: + raise RuntimeError( + f"Unsupported service model file format (must be .json or .json.gz): {local_model_path}" + ) @pytest.fixture(scope="session") -def install_service_model(service_model: ServiceModel) -> Generator[str, None, None]: +def install_service_model(service_model: ServiceModel, region: str) -> Generator[str, None, None]: LOG.info("Installing service model and configuring boto to use it for API calls") - with service_model.install() as model_path: - LOG.info(f"Installed service model to {model_path}") - yield model_path + with service_model.install(region) as service_model_install: + LOG.info(f"Installed service model to {service_model_install}") + yield service_model_install @pytest.fixture(scope="session") @@ -365,13 +377,12 @@ def worker_config( deadline_resources: DeadlineResources, codeartifact: CodeArtifactRepositoryInfo, service_model: ServiceModel, + region: str, ) -> DeadlineWorkerConfiguration: """ Builds the configuration for a DeadlineWorker. Environment Variables: - WORKER_REGION: The AWS region to configure the worker for - Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2 WORKER_POSIX_USER: The POSIX user to configure the worker for Defaults to "deadline-worker" WORKER_POSIX_SHARED_GROUP: The shared POSIX group to configure the worker user and job user with @@ -387,6 +398,12 @@ def worker_config( """ file_mappings: list[tuple[str, str]] = [] + # Deprecated environment variable + if os.getenv("WORKER_REGION") is not None: + raise Exception( + "The environment variable WORKER_REGION is no longer supported. Please use REGION instead." + ) + # Prepare the Worker agent Python package worker_agent_whl_path = os.getenv("WORKER_AGENT_WHL_PATH") if worker_agent_whl_path: @@ -410,19 +427,15 @@ def worker_config( LOG.info(f"Using Worker agent package {worker_agent_requirement_specifier}") # Path map the service model - dst_path = posixpath.join("/tmp", "deadline-cloud-service-model.json") - path_mapped_model = ServiceModel( - file_path=dst_path, - api_version=service_model.api_version, - service_name=service_model.service_name, - ) + dst_path = posixpath.join("/tmp", os.path.basename(service_model.file_path)) + path_mapped_model = replace(service_model, file_path=dst_path) LOG.info(f"The service model will be copied to {dst_path} on the Worker environment") file_mappings.append((service_model.file_path, dst_path)) return DeadlineWorkerConfiguration( farm_id=deadline_resources.farm.id, fleet_id=deadline_resources.fleet.id, - region=os.getenv("WORKER_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")), + region=region, user=os.getenv("WORKER_POSIX_USER", "deadline-worker"), group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"), allow_shutdown=True, @@ -438,7 +451,6 @@ def worker_config( @pytest.fixture(scope="session") def worker( request: pytest.FixtureRequest, - deadline_client: DeadlineClient, worker_config: DeadlineWorkerConfiguration, ) -> Generator[DeadlineWorker, None, None]: """ @@ -484,7 +496,6 @@ def worker( worker = EC2InstanceWorker( ec2_client=ec2_client, - deadline_client=deadline_client, s3_client=s3_client, bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name, ssm_client=ssm_client, @@ -496,6 +507,10 @@ def worker( ) def stop_worker(): + if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() == "true": + LOG.info("KEEP_WORKER_AFTER_FAILURE is set, not stopping worker") + return + try: worker.stop() except Exception as e: @@ -509,9 +524,8 @@ def stop_worker(): worker.start() except Exception as e: LOG.exception(f"Failed to start worker: {e}") - if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() != "true": - LOG.info("Stopping worker because it failed to start") - stop_worker() + LOG.info("Stopping worker because it failed to start") + stop_worker() raise yield worker @@ -550,4 +564,9 @@ def _find_latest_service_model_file(service_name: str) -> str: service_name, loader.determine_latest_version(service_name, "service-2"), "service-2" ) _, service_model_path = loader.load_data_with_path(full_name) - return f"{service_model_path}.json" + service_model_files = glob.glob(f"{service_model_path}.*") + if len(service_model_files) > 1: + raise RuntimeError( + f"Expected exactly one file to match glob '{service_model_path}.*, but got: {service_model_files}" + ) + return service_model_files[0] diff --git a/src/deadline_test_fixtures/models.py b/src/deadline_test_fixtures/models.py index 5229b94..9a2752d 100644 --- a/src/deadline_test_fixtures/models.py +++ b/src/deadline_test_fixtures/models.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - from __future__ import annotations +import gzip import json import os import re @@ -10,7 +10,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Generator, Literal +from typing import Any, Generator, List, Literal, Optional @dataclass(frozen=True) @@ -88,8 +88,7 @@ def path_mappings(self) -> list[tuple[str, str]]: @dataclass(frozen=True) class ServiceModel: file_path: str - api_version: str - service_name: str + model: dict[str, Any] @staticmethod def from_json_file(path: str) -> ServiceModel: @@ -97,23 +96,44 @@ def from_json_file(path: str) -> ServiceModel: model = json.load(f) return ServiceModel( file_path=path, - api_version=model["metadata"]["apiVersion"], - service_name=model["metadata"]["serviceId"], + model=model, + ) + + @staticmethod + def from_json_gz_file(path: str) -> GzipServiceModel: + with gzip.open(path, mode="r") as f: + model = json.load(f) + model = GzipServiceModel( + file_path=path, + model=model, ) + return model @contextmanager - def install(self) -> Generator[str, None, None]: + def install(self, region: str) -> Generator[str, None, None]: """ Copies the model to a temporary directory in the structure expected by boto and sets the AWS_DATA_PATH environment variable to it """ try: old_aws_data_path = os.environ.get("AWS_DATA_PATH") - src_file = Path(self.file_path) + old_endpoint_url = os.environ.get("AWS_ENDPOINT_URL_DEADLINE") + + # Set endpoint URL + os.environ["AWS_ENDPOINT_URL_DEADLINE"] = self.deadline_endpoint_url_fmt_str.format( + region + ) + + # Install service model with tempfile.TemporaryDirectory() as tmpdir: - json_path = Path(tmpdir) / self.service_name / self.api_version / "service-2.json" + json_path = ( + Path(tmpdir) + / self.service_name + / self.api_version + / Path(self.file_path).with_suffix(".json").name + ) json_path.parent.mkdir(parents=True) - json_path.write_text(src_file.read_text()) + json_path.write_text(json.dumps(self.model)) os.environ["AWS_DATA_PATH"] = tmpdir yield str(tmpdir) finally: @@ -121,19 +141,55 @@ def install(self) -> Generator[str, None, None]: os.environ["AWS_DATA_PATH"] = old_aws_data_path else: del os.environ["AWS_DATA_PATH"] + if old_endpoint_url: + os.environ["AWS_ENDPOINT_URL_DEADLINE"] = old_endpoint_url + else: + del os.environ["AWS_ENDPOINT_URL_DEADLINE"] @property - def install_command(self) -> str: - return " ".join( - [ - "aws", - "configure", - "add-model", - "--service-model", - f"file://{self.file_path}", - *(["--service-name", self.service_name] if self.service_name else []), - ] - ) + def install_commands(self) -> List[str]: + return [_get_aws_cli_install_command(self.file_path)] + + @property + def api_version(self) -> str: + return self.model["metadata"]["apiVersion"] + + @property + def service_name(self) -> str: + return self.model["metadata"]["serviceId"] + + @property + def endpoint_prefix(self) -> str: + return self.model["metadata"]["endpointPrefix"] + + @property + def deadline_endpoint_url_fmt_str(self) -> str: + """Format string for the service endpoint URL with one format field for the region""" + return f"https://{self.endpoint_prefix}.{{}}.amazonaws.com" + + +class GzipServiceModel(ServiceModel): + @property + def install_commands(self) -> List[str]: + # Strip the .gz suffix, since it will only be .json + json_path = Path(self.file_path).with_suffix("") + return [ + f"gunzip {self.file_path}", + _get_aws_cli_install_command(str(json_path)), + ] + + +def _get_aws_cli_install_command(model_path: str, service_name: Optional[str] = None) -> str: + return " ".join( + [ + "aws", + "configure", + "add-model", + "--service-model", + f"file://{model_path}", + *(["--service-name", service_name] if service_name else []), + ] + ) @dataclass(frozen=True) diff --git a/test/unit/deadline/test_worker.py b/test/unit/deadline/test_worker.py index 87b4022..6686e91 100644 --- a/test/unit/deadline/test_worker.py +++ b/test/unit/deadline/test_worker.py @@ -58,11 +58,6 @@ def region(boto_config: dict[str, str]) -> str: return boto_config["AWS_DEFAULT_REGION"] -@pytest.fixture -def deadline_client() -> MagicMock: - return MagicMock() - - @pytest.fixture def worker_config(region: str) -> DeadlineWorkerConfiguration: return DeadlineWorkerConfiguration( @@ -88,8 +83,7 @@ def worker_config(region: str) -> DeadlineWorkerConfiguration: ], service_model=ServiceModel( file_path="/tmp/deadline/1234-12-12/service-2.json", - api_version="1234-12-12", - service_name="deadline", + model=MagicMock(), ), ) @@ -151,7 +145,6 @@ def bootstrap_bucket_name(self, region: str) -> str: @pytest.fixture def worker( self, - deadline_client: MagicMock, worker_config: DeadlineWorkerConfiguration, subnet_id: str, security_group_id: str, @@ -166,7 +159,6 @@ def worker( s3_client=boto3.client("s3"), ec2_client=boto3.client("ec2"), ssm_client=boto3.client("ssm"), - deadline_client=deadline_client, configuration=worker_config, )