Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
reduce nesting (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Feb 28, 2024
1 parent ff97b40 commit 44cf8e2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 78 deletions.
2 changes: 1 addition & 1 deletion docs/gen_examples_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_code_examples(obj: Union[ModuleType, Callable]) -> Set[str]:
for section in parsed_sections:
if section.kind == DocstringSectionKind.examples:
code_example = "\n".join(
(part[1] for part in section.as_dict().get("value", []))
part[1] for part in section.as_dict().get("value", [])
)
if not skip_block_load_code_example(code_example):
code_examples.add(code_example)
Expand Down
141 changes: 64 additions & 77 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,79 +675,12 @@ def _create_task_and_wait_for_start(
task_definition = self._prepare_task_definition(
configuration, region=ecs_client.meta.region_name
)

cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(
flow_run.deployment_id
(
task_definition_arn,
new_task_definition_registered,
) = self._get_or_register_task_definition(
logger, ecs_client, configuration, flow_run, task_definition
)

if cached_task_definition_arn:
# Read the task definition to see if the cached task definition is valid
try:
cached_task_definition = self._retrieve_task_definition(
logger, ecs_client, cached_task_definition_arn
)
except Exception as exc:
logger.warning(
"Failed to retrieve cached task definition"
f" {cached_task_definition_arn!r}: {exc!r}"
)
# Clear from cache
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None
else:
if not cached_task_definition["status"] == "ACTIVE":
# Cached task definition is not active
logger.warning(
"Cached task definition"
f" {cached_task_definition_arn!r} is not active"
)
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None
elif not self._task_definitions_equal(
task_definition, cached_task_definition
):
# Cached task definition is not valid
logger.warning(
"Cached task definition"
f" {cached_task_definition_arn!r} does not meet"
" requirements"
)
_TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None)
cached_task_definition_arn = None

# use the family as a fallback if we don't have a local cached definition
if (
configuration.match_latest_revision_in_family
and not cached_task_definition_arn
):
try:
task_definition_from_family = self._retrieve_task_definition(
logger,
ecs_client,
task_definition.get("family", ECS_DEFAULT_FAMILY),
)
except Exception as exc:
logger.warning(
"Failed to retrieve a definition for task family "
f"{task_definition.get('family', ECS_DEFAULT_FAMILY)!r}: "
f"{exc!r}"
)
else:
if self._task_definitions_equal(
task_definition, task_definition_from_family
):
cached_task_definition_arn = task_definition_from_family[
"taskDefinitionArn"
]

if not cached_task_definition_arn:
task_definition_arn = self._register_task_definition(
logger, ecs_client, task_definition
)
new_task_definition_registered = True
else:
task_definition_arn = cached_task_definition_arn

else:
task_definition = self._retrieve_task_definition(
logger, ecs_client, task_definition_arn
Expand All @@ -760,17 +693,13 @@ def _create_task_and_wait_for_start(

self._validate_task_definition(task_definition, configuration)

# Update the cached task definition ARN to avoid re-registering the task
# definition on this worker unless necessary; registration is agressively
# rate limited by AWS
_TASK_DEFINITION_CACHE[flow_run.deployment_id] = task_definition_arn

logger.info(f"Using ECS task definition {task_definition_arn!r}...")
logger.debug(
f"Task definition {json.dumps(task_definition, indent=2, default=str)}"
)

# Prepare the task run request
task_run_request = self._prepare_task_run_request(
configuration,
task_definition,
Expand All @@ -791,7 +720,6 @@ def _create_task_and_wait_for_start(
self._report_task_run_creation_failure(configuration, task_run_request, exc)
raise

# Raises an exception if the task does not start
logger.info("Waiting for ECS task run to start...")
self._wait_for_task_start(
logger,
Expand All @@ -804,6 +732,65 @@ def _create_task_and_wait_for_start(

return task_arn, cluster_arn, task_definition, new_task_definition_registered

def _get_or_register_task_definition(
self,
logger: logging.Logger,
ecs_client: _ECSClient,
configuration: ECSJobConfiguration,
flow_run: FlowRun,
task_definition: dict,
) -> Tuple[str, bool]:
"""Get or register a task definition for the given flow run.
Returns a tuple of the task definition ARN and a bool indicating if the task
definition is newly registered.
"""

cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(flow_run.deployment_id)
new_task_definition_registered = False

if cached_task_definition_arn:
try:
cached_task_definition = self._retrieve_task_definition(
logger, ecs_client, cached_task_definition_arn
)
if not cached_task_definition[
"status"
] == "ACTIVE" or not self._task_definitions_equal(
task_definition, cached_task_definition
):
cached_task_definition_arn = None
except Exception:
cached_task_definition_arn = None

if (
not cached_task_definition_arn
and configuration.match_latest_revision_in_family
):
family_name = task_definition.get("family", ECS_DEFAULT_FAMILY)
try:
task_definition_from_family = self._retrieve_task_definition_by_family(
logger, ecs_client, family_name
)
if task_definition_from_family and self._task_definitions_equal(
task_definition, task_definition_from_family
):
cached_task_definition_arn = task_definition_from_family[
"taskDefinitionArn"
]
except Exception:
pass

if not cached_task_definition_arn:
task_definition_arn = self._register_task_definition(
logger, ecs_client, task_definition
)
new_task_definition_registered = True
else:
task_definition_arn = cached_task_definition_arn

return task_definition_arn, new_task_definition_registered

def _watch_task_and_get_exit_code(
self,
logger: logging.Logger,
Expand Down

0 comments on commit 44cf8e2

Please sign in to comment.