Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: set AWS_ENDPOINT_URL_DEADLINE after installing service model #96

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/deadline_test_fixtures/deadline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
from dataclasses import dataclass, field, InitVar, replace
from typing import Any, ClassVar, Optional, cast

from .client import DeadlineClient
from ..models import (
PipInstall,
PosixSessionUser,
ServiceModel,
)
from ..util import call_api, wait_for

Expand Down Expand Up @@ -54,9 +52,9 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str: #
# fmt: on
]

if config.service_model:
if config.service_model_path:
cmds.append(
f"runuser -l {config.user} -s /bin/bash -c '{config.service_model.install_command}'"
f"runuser -l {config.user} -s /bin/bash -c 'aws configure add-model --service-model file://{config.service_model_path}'"
)

return " && ".join(cmds)
Expand Down Expand Up @@ -128,7 +126,7 @@ class DeadlineWorkerConfiguration:
)
start_service: bool = False
no_install_service: bool = False
service_model: ServiceModel | None = None
service_model_path: str | None = None
file_mappings: list[tuple[str, str]] | None = None
"""Mapping of files to copy from host environment to worker environment"""
pre_install_commands: list[str] | None = None
Expand All @@ -146,7 +144,6 @@ class EC2InstanceWorker(DeadlineWorker):
s3_client: botocore.client.BaseClient
ec2_client: botocore.client.BaseClient
ssm_client: botocore.client.BaseClient
deadline_client: DeadlineClient
configuration: DeadlineWorkerConfiguration

instance_id: Optional[str] = field(init=False, default=None)
Expand Down
7 changes: 6 additions & 1 deletion src/deadline_test_fixtures/example_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ export CODEARTIFACT_REGION

# --- OPTIONAL --- #

# The AWS region to use
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export REGION

# Extra local path for boto to look for AWS models in
# Does not apply to the worker
export AWS_DATA_PATH
Expand All @@ -38,9 +42,10 @@ export AWS_DATA_PATH
# Default is to pip install the latest "deadline-cloud-worker-agent" package
export WORKER_AGENT_WHL_PATH

# DEPRECATED: Use REGION instead
# The AWS region to configure the worker for
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export WORKER_REGION
# export WORKER_REGION

# The POSIX user to configure the worker for
# Defaults to "deadline-worker"
Expand Down
102 changes: 65 additions & 37 deletions src/deadline_test_fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import botocore.loaders
import boto3
import glob
import json
import logging
import os
import pathlib
import posixpath
import pytest
import tempfile
Expand Down Expand Up @@ -141,6 +143,11 @@ def codeartifact() -> CodeArtifactRepositoryInfo:
)


@pytest.fixture(scope="session")
def region() -> str:
return os.getenv("REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2"))


@pytest.fixture(scope="session")
def service_model() -> Generator[ServiceModel, None, None]:
service_model_s3_uri = os.getenv("DEADLINE_SERVICE_MODEL_S3_URI")
Expand Down Expand Up @@ -168,15 +175,22 @@ def service_model() -> Generator[ServiceModel, None, None]:
if not local_model_path:
local_model_path = _find_latest_service_model_file("deadline")
LOG.info(f"Using service model at: {local_model_path}")
yield ServiceModel.from_json_file(local_model_path)
if local_model_path.endswith(".json"):
yield ServiceModel.from_json_file(local_model_path)
elif local_model_path.endswith(".json.gz"):
yield ServiceModel.from_json_gz_file(local_model_path)
else:
raise RuntimeError(
f"Unsupported service model file format (must be .json or .json.gz): {local_model_path}"
)


@pytest.fixture(scope="session")
def install_service_model(service_model: ServiceModel) -> Generator[str, None, None]:
def install_service_model(service_model: ServiceModel, region: str) -> Generator[str, None, None]:
LOG.info("Installing service model and configuring boto to use it for API calls")
with service_model.install() as model_path:
LOG.info(f"Installed service model to {model_path}")
yield model_path
with service_model.install(region) as service_model_install:
LOG.info(f"Installed service model to {service_model_install}")
yield service_model_install


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -365,13 +379,12 @@ def worker_config(
deadline_resources: DeadlineResources,
codeartifact: CodeArtifactRepositoryInfo,
service_model: ServiceModel,
) -> DeadlineWorkerConfiguration:
region: str,
) -> Generator[DeadlineWorkerConfiguration, None, None]:
"""
Builds the configuration for a DeadlineWorker.

Environment Variables:
WORKER_REGION: The AWS region to configure the worker for
Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
WORKER_POSIX_USER: The POSIX user to configure the worker for
Defaults to "deadline-worker"
WORKER_POSIX_SHARED_GROUP: The shared POSIX group to configure the worker user and job user with
Expand All @@ -387,6 +400,12 @@ def worker_config(
"""
file_mappings: list[tuple[str, str]] = []

# Deprecated environment variable
if os.getenv("WORKER_REGION") is not None:
raise Exception(
"The environment variable WORKER_REGION is no longer supported. Please use REGION instead."
)

# Prepare the Worker agent Python package
worker_agent_whl_path = os.getenv("WORKER_AGENT_WHL_PATH")
if worker_agent_whl_path:
Expand All @@ -410,35 +429,36 @@ def worker_config(
LOG.info(f"Using Worker agent package {worker_agent_requirement_specifier}")

# Path map the service model
dst_path = posixpath.join("/tmp", "deadline-cloud-service-model.json")
path_mapped_model = ServiceModel(
file_path=dst_path,
api_version=service_model.api_version,
service_name=service_model.service_name,
)
LOG.info(f"The service model will be copied to {dst_path} on the Worker environment")
file_mappings.append((service_model.file_path, dst_path))

return DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
region=os.getenv("WORKER_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")),
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
allow_shutdown=True,
worker_agent_install=PipInstall(
requirement_specifiers=[worker_agent_requirement_specifier],
codeartifact=codeartifact,
),
service_model=path_mapped_model,
file_mappings=file_mappings or None,
)
with tempfile.TemporaryDirectory() as tmpdir:
src_path = pathlib.Path(tmpdir) / f"{service_model.service_name}-service-2.json"

LOG.info(f"Staging service model to {src_path} for uploading to S3")
with src_path.open(mode="w") as f:
json.dump(service_model.model, f)

dst_path = posixpath.join("/tmp", src_path.name)
LOG.info(f"The service model will be copied to {dst_path} on the Worker environment")
file_mappings.append((str(src_path), dst_path))

yield DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
region=region,
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
allow_shutdown=True,
worker_agent_install=PipInstall(
requirement_specifiers=[worker_agent_requirement_specifier],
codeartifact=codeartifact,
),
service_model_path=dst_path,
file_mappings=file_mappings or None,
)
jericht marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture(scope="session")
def worker(
request: pytest.FixtureRequest,
deadline_client: DeadlineClient,
worker_config: DeadlineWorkerConfiguration,
) -> Generator[DeadlineWorker, None, None]:
"""
Expand Down Expand Up @@ -484,7 +504,6 @@ def worker(

worker = EC2InstanceWorker(
ec2_client=ec2_client,
deadline_client=deadline_client,
s3_client=s3_client,
bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name,
ssm_client=ssm_client,
Expand All @@ -496,6 +515,11 @@ def worker(
)

def stop_worker():
if request.session.testsfailed > 0:
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() == "true":
LOG.info("KEEP_WORKER_AFTER_FAILURE is set, not stopping worker")
return

try:
worker.stop()
except Exception as e:
Expand All @@ -509,9 +533,8 @@ def stop_worker():
worker.start()
except Exception as e:
LOG.exception(f"Failed to start worker: {e}")
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() != "true":
LOG.info("Stopping worker because it failed to start")
stop_worker()
LOG.info("Stopping worker because it failed to start")
stop_worker()
raise

yield worker
Expand Down Expand Up @@ -550,4 +573,9 @@ def _find_latest_service_model_file(service_name: str) -> str:
service_name, loader.determine_latest_version(service_name, "service-2"), "service-2"
)
_, service_model_path = loader.load_data_with_path(full_name)
return f"{service_model_path}.json"
service_model_files = glob.glob(f"{service_model_path}.*")
if len(service_model_files) > 1:
raise RuntimeError(
f"Expected exactly one file to match glob '{service_model_path}.*, but got: {service_model_files}"
)
return service_model_files[0]
61 changes: 37 additions & 24 deletions src/deadline_test_fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from __future__ import annotations

import gzip
import json
import os
import re
Expand All @@ -10,7 +10,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Literal
from typing import Any, Generator, Literal


@dataclass(frozen=True)
Expand Down Expand Up @@ -87,53 +87,66 @@ def path_mappings(self) -> list[tuple[str, str]]:

@dataclass(frozen=True)
class ServiceModel:
file_path: str
api_version: str
service_name: str
model: dict[str, Any]

@staticmethod
def from_json_file(path: str) -> ServiceModel:
with open(path) as f:
model = json.load(f)
return ServiceModel(
file_path=path,
api_version=model["metadata"]["apiVersion"],
service_name=model["metadata"]["serviceId"],
)
return ServiceModel(model=model)

@staticmethod
def from_json_gz_file(path: str) -> ServiceModel:
with gzip.open(path, mode="r") as f:
model = json.load(f)
return ServiceModel(model=model)

@contextmanager
def install(self) -> Generator[str, None, None]:
def install(self, region: str) -> Generator[str, None, None]:
"""
Copies the model to a temporary directory in the structure expected by boto
and sets the AWS_DATA_PATH environment variable to it
"""
try:
old_aws_data_path = os.environ.get("AWS_DATA_PATH")
src_file = Path(self.file_path)
old_endpoint_url = os.environ.get("AWS_ENDPOINT_URL_DEADLINE")

# Set endpoint URL
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = self.endpoint_url_fmt_str.format(region)

# Install service model
with tempfile.TemporaryDirectory() as tmpdir:
json_path = Path(tmpdir) / self.service_name / self.api_version / "service-2.json"
json_path.parent.mkdir(parents=True)
json_path.write_text(src_file.read_text())
json_path.write_text(json.dumps(self.model))
os.environ["AWS_DATA_PATH"] = tmpdir
yield str(tmpdir)
finally:
if old_aws_data_path:
os.environ["AWS_DATA_PATH"] = old_aws_data_path
else:
del os.environ["AWS_DATA_PATH"]
if old_endpoint_url:
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = old_endpoint_url
else:
del os.environ["AWS_ENDPOINT_URL_DEADLINE"]

@property
def install_command(self) -> str:
return " ".join(
[
"aws",
"configure",
"add-model",
"--service-model",
f"file://{self.file_path}",
*(["--service-name", self.service_name] if self.service_name else []),
]
)
def api_version(self) -> str:
return self.model["metadata"]["apiVersion"]

@property
def service_name(self) -> str:
return self.model["metadata"]["serviceId"]

@property
def endpoint_prefix(self) -> str:
return self.model["metadata"]["endpointPrefix"]

@property
def endpoint_url_fmt_str(self) -> str:
"""Format string for the service endpoint URL with one format field for the region"""
return f"https://{self.endpoint_prefix}.{{}}.amazonaws.com"


@dataclass(frozen=True)
Expand Down
Loading