Skip to content

Commit

Permalink
fix: set AWS_ENDPOINT_URL_DEADLINE after installing service model
Browse files Browse the repository at this point in the history
Signed-off-by: Jericho Tolentino <68654047+jericht@users.noreply.github.com>
  • Loading branch information
jericht committed Apr 17, 2024
1 parent c3c6d75 commit b4e40b6
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 55 deletions.
7 changes: 4 additions & 3 deletions src/deadline_test_fixtures/deadline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dataclasses import dataclass, field, InitVar, replace
from typing import Any, ClassVar, Optional, cast

from .client import DeadlineClient
from ..models import (
PipInstall,
PosixSessionUser,
Expand Down Expand Up @@ -56,7 +55,10 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str: #

if config.service_model:
cmds.append(
f"runuser -l {config.user} -s /bin/bash -c '{config.service_model.install_command}'"
f"runuser -l {config.user} -s /bin/bash -c '{' && '.join(config.service_model.install_commands)}'"
)
cmds.append(
f'echo "AWS_ENDPOINT_URL_DEADLINE=\\"{config.service_model.deadline_endpoint_url_fmt_str.format(config.region)}\\"" >> /etc/environment'
)

return " && ".join(cmds)
Expand Down Expand Up @@ -146,7 +148,6 @@ class EC2InstanceWorker(DeadlineWorker):
s3_client: botocore.client.BaseClient
ec2_client: botocore.client.BaseClient
ssm_client: botocore.client.BaseClient
deadline_client: DeadlineClient
configuration: DeadlineWorkerConfiguration

instance_id: Optional[str] = field(init=False, default=None)
Expand Down
7 changes: 6 additions & 1 deletion src/deadline_test_fixtures/example_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ export CODEARTIFACT_REGION

# --- OPTIONAL --- #

# The AWS region to use
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export REGION

# Extra local path for boto to look for AWS models in
# Does not apply to the worker
export AWS_DATA_PATH
Expand All @@ -38,9 +42,10 @@ export AWS_DATA_PATH
# Default is to pip install the latest "deadline-cloud-worker-agent" package
export WORKER_AGENT_WHL_PATH

# DEPRECATED: Use REGION instead
# The AWS region to configure the worker for
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export WORKER_REGION
# export WORKER_REGION

# The POSIX user to configure the worker for
# Defaults to "deadline-worker"
Expand Down
61 changes: 40 additions & 21 deletions src/deadline_test_fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
import tempfile
from contextlib import ExitStack, contextmanager
from dataclasses import InitVar, dataclass, field, fields, MISSING
from dataclasses import InitVar, dataclass, field, fields, replace, MISSING
from typing import Any, Generator, TypeVar

from .deadline.client import DeadlineClient
Expand Down Expand Up @@ -141,6 +141,11 @@ def codeartifact() -> CodeArtifactRepositoryInfo:
)


@pytest.fixture(scope="session")
def region() -> str:
return os.getenv("REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2"))


@pytest.fixture(scope="session")
def service_model() -> Generator[ServiceModel, None, None]:
service_model_s3_uri = os.getenv("DEADLINE_SERVICE_MODEL_S3_URI")
Expand Down Expand Up @@ -168,15 +173,22 @@ def service_model() -> Generator[ServiceModel, None, None]:
if not local_model_path:
local_model_path = _find_latest_service_model_file("deadline")
LOG.info(f"Using service model at: {local_model_path}")
yield ServiceModel.from_json_file(local_model_path)
if local_model_path.endswith(".json"):
yield ServiceModel.from_json_file(local_model_path)
elif local_model_path.endswith(".json.gz"):
yield ServiceModel.from_json_gz_file(local_model_path)
else:
raise RuntimeError(
f"Unsupported service model file format (must be .json or .json.gz): {local_model_path}"
)


@pytest.fixture(scope="session")
def install_service_model(service_model: ServiceModel) -> Generator[str, None, None]:
def install_service_model(service_model: ServiceModel, region: str) -> Generator[str, None, None]:
LOG.info("Installing service model and configuring boto to use it for API calls")
with service_model.install() as model_path:
LOG.info(f"Installed service model to {model_path}")
yield model_path
with service_model.install(region) as service_model_install:
LOG.info(f"Installed service model to {service_model_install}")
yield service_model_install


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -365,13 +377,12 @@ def worker_config(
deadline_resources: DeadlineResources,
codeartifact: CodeArtifactRepositoryInfo,
service_model: ServiceModel,
region: str,
) -> DeadlineWorkerConfiguration:
"""
Builds the configuration for a DeadlineWorker.
Environment Variables:
WORKER_REGION: The AWS region to configure the worker for
Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
WORKER_POSIX_USER: The POSIX user to configure the worker for
Defaults to "deadline-worker"
WORKER_POSIX_SHARED_GROUP: The shared POSIX group to configure the worker user and job user with
Expand All @@ -387,6 +398,12 @@ def worker_config(
"""
file_mappings: list[tuple[str, str]] = []

# Deprecated environment variable
if os.getenv("WORKER_REGION") is not None:
raise Exception(
"The environment variable WORKER_REGION is no longer supported. Please use REGION instead."
)

# Prepare the Worker agent Python package
worker_agent_whl_path = os.getenv("WORKER_AGENT_WHL_PATH")
if worker_agent_whl_path:
Expand All @@ -410,19 +427,15 @@ def worker_config(
LOG.info(f"Using Worker agent package {worker_agent_requirement_specifier}")

# Path map the service model
dst_path = posixpath.join("/tmp", "deadline-cloud-service-model.json")
path_mapped_model = ServiceModel(
file_path=dst_path,
api_version=service_model.api_version,
service_name=service_model.service_name,
)
dst_path = posixpath.join("/tmp", os.path.basename(service_model.file_path))
path_mapped_model = replace(service_model, file_path=dst_path)
LOG.info(f"The service model will be copied to {dst_path} on the Worker environment")
file_mappings.append((service_model.file_path, dst_path))

return DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
region=os.getenv("WORKER_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")),
region=region,
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
allow_shutdown=True,
Expand All @@ -438,7 +451,6 @@ def worker_config(
@pytest.fixture(scope="session")
def worker(
request: pytest.FixtureRequest,
deadline_client: DeadlineClient,
worker_config: DeadlineWorkerConfiguration,
) -> Generator[DeadlineWorker, None, None]:
"""
Expand Down Expand Up @@ -484,7 +496,6 @@ def worker(

worker = EC2InstanceWorker(
ec2_client=ec2_client,
deadline_client=deadline_client,
s3_client=s3_client,
bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name,
ssm_client=ssm_client,
Expand All @@ -496,6 +507,10 @@ def worker(
)

def stop_worker():
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() == "true":
LOG.info("KEEP_WORKER_AFTER_FAILURE is set, not stopping worker")
return

try:
worker.stop()
except Exception as e:
Expand All @@ -509,9 +524,8 @@ def stop_worker():
worker.start()
except Exception as e:
LOG.exception(f"Failed to start worker: {e}")
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() != "true":
LOG.info("Stopping worker because it failed to start")
stop_worker()
LOG.info("Stopping worker because it failed to start")
stop_worker()
raise

yield worker
Expand Down Expand Up @@ -550,4 +564,9 @@ def _find_latest_service_model_file(service_name: str) -> str:
service_name, loader.determine_latest_version(service_name, "service-2"), "service-2"
)
_, service_model_path = loader.load_data_with_path(full_name)
return f"{service_model_path}.json"
service_model_files = glob.glob(f"{service_model_path}.*")
if len(service_model_files) > 1:
raise RuntimeError(
f"Expected exactly one file to match glob '{service_model_path}.*, but got: {service_model_files}"
)
return service_model_files[0]
98 changes: 77 additions & 21 deletions src/deadline_test_fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from __future__ import annotations

import gzip
import json
import os
import re
Expand All @@ -10,7 +10,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Literal
from typing import Any, Generator, List, Literal, Optional


@dataclass(frozen=True)
Expand Down Expand Up @@ -88,52 +88,108 @@ def path_mappings(self) -> list[tuple[str, str]]:
@dataclass(frozen=True)
class ServiceModel:
file_path: str
api_version: str
service_name: str
model: dict[str, Any]

@staticmethod
def from_json_file(path: str) -> ServiceModel:
with open(path) as f:
model = json.load(f)
return ServiceModel(
file_path=path,
api_version=model["metadata"]["apiVersion"],
service_name=model["metadata"]["serviceId"],
model=model,
)

@staticmethod
def from_json_gz_file(path: str) -> GzipServiceModel:
with gzip.open(path, mode="r") as f:
model = json.load(f)
model = GzipServiceModel(
file_path=path,
model=model,
)
return model

@contextmanager
def install(self) -> Generator[str, None, None]:
def install(self, region: str) -> Generator[str, None, None]:
"""
Copies the model to a temporary directory in the structure expected by boto
and sets the AWS_DATA_PATH environment variable to it
"""
try:
old_aws_data_path = os.environ.get("AWS_DATA_PATH")
src_file = Path(self.file_path)
old_endpoint_url = os.environ.get("AWS_ENDPOINT_URL_DEADLINE")

# Set endpoint URL
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = self.deadline_endpoint_url_fmt_str.format(
region
)

# Install service model
with tempfile.TemporaryDirectory() as tmpdir:
json_path = Path(tmpdir) / self.service_name / self.api_version / "service-2.json"
json_path = (
Path(tmpdir)
/ self.service_name
/ self.api_version
/ Path(self.file_path).with_suffix(".json").name
)
json_path.parent.mkdir(parents=True)
json_path.write_text(src_file.read_text())
json_path.write_text(json.dumps(self.model))
os.environ["AWS_DATA_PATH"] = tmpdir
yield str(tmpdir)
finally:
if old_aws_data_path:
os.environ["AWS_DATA_PATH"] = old_aws_data_path
else:
del os.environ["AWS_DATA_PATH"]
if old_endpoint_url:
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = old_endpoint_url
else:
del os.environ["AWS_ENDPOINT_URL_DEADLINE"]

@property
def install_command(self) -> str:
return " ".join(
[
"aws",
"configure",
"add-model",
"--service-model",
f"file://{self.file_path}",
*(["--service-name", self.service_name] if self.service_name else []),
]
)
def install_commands(self) -> List[str]:
return [_get_aws_cli_install_command(self.file_path)]

@property
def api_version(self) -> str:
return self.model["metadata"]["apiVersion"]

@property
def service_name(self) -> str:
return self.model["metadata"]["serviceId"]

@property
def endpoint_prefix(self) -> str:
return self.model["metadata"]["endpointPrefix"]

@property
def deadline_endpoint_url_fmt_str(self) -> str:
"""Format string for the service endpoint URL with one format field for the region"""
return f"https://{self.endpoint_prefix}.{{}}.amazonaws.com"


class GzipServiceModel(ServiceModel):
@property
def install_commands(self) -> List[str]:
# Strip the .gz suffix, since it will only be .json
json_path = Path(self.file_path).with_suffix("")
return [
f"gunzip {self.file_path}",
_get_aws_cli_install_command(str(json_path)),
]


def _get_aws_cli_install_command(model_path: str, service_name: Optional[str] = None) -> str:
return " ".join(
[
"aws",
"configure",
"add-model",
"--service-model",
f"file://{model_path}",
*(["--service-name", service_name] if service_name else []),
]
)


@dataclass(frozen=True)
Expand Down
10 changes: 1 addition & 9 deletions test/unit/deadline/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def region(boto_config: dict[str, str]) -> str:
return boto_config["AWS_DEFAULT_REGION"]


@pytest.fixture
def deadline_client() -> MagicMock:
return MagicMock()


@pytest.fixture
def worker_config(region: str) -> DeadlineWorkerConfiguration:
return DeadlineWorkerConfiguration(
Expand All @@ -88,8 +83,7 @@ def worker_config(region: str) -> DeadlineWorkerConfiguration:
],
service_model=ServiceModel(
file_path="/tmp/deadline/1234-12-12/service-2.json",
api_version="1234-12-12",
service_name="deadline",
model=MagicMock(),
),
)

Expand Down Expand Up @@ -151,7 +145,6 @@ def bootstrap_bucket_name(self, region: str) -> str:
@pytest.fixture
def worker(
self,
deadline_client: MagicMock,
worker_config: DeadlineWorkerConfiguration,
subnet_id: str,
security_group_id: str,
Expand All @@ -166,7 +159,6 @@ def worker(
s3_client=boto3.client("s3"),
ec2_client=boto3.client("ec2"),
ssm_client=boto3.client("ssm"),
deadline_client=deadline_client,
configuration=worker_config,
)

Expand Down

0 comments on commit b4e40b6

Please sign in to comment.