From 1294cb4d91eec7815fd6f3b5f6407097db2260ef Mon Sep 17 00:00:00 2001 From: Daniel Neilson <53624638+ddneilson@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:10:35 -0500 Subject: [PATCH] fix: respect service's suggested retryAfter when throttled (#39) When calling a deadline cloud service API and getting a throttle/retry response the exception object may contain a "retryAfterSeconds" field alongside the error. When that field is present, the calling client should treat that as a request to retry in no sooner than the given number of seconds; it is a load-shedding mechanism for the service. We should respect the service's request. Solution: Added to the logic of all of the deadline-cloud API wrappers to have them extract the value of the "retryAfterSeconds" field if it's present, and pass that to our backoff-delay calculator. We use the value as a lower limit on the returned delay. I also made the scheduler use the API wrapper for update_worker; it still had its own implementation that didn't properly handle exceptions. This necessitated adding the ability to interrupt the update_worker's throttled-retries so preserve the functionality at that call site. Signed-off-by: Daniel Neilson <53624638+ddneilson@users.noreply.github.com> Signed-off-by: Graeme McHale --- .../aws/deadline/__init__.py | 72 +++++++++-- .../scheduler/scheduler.py | 68 +++++------ .../startup/bootstrap.py | 7 +- .../startup/entrypoint.py | 3 +- .../test_assume_fleet_role_for_worker.py | 36 ++++++ .../test_assume_queue_role_for_worker.py | 41 +++++++ .../aws/deadline/test_batch_get_job_entity.py | 31 ++++- test/unit/aws/deadline/test_create_worker.py | 32 ++++- test/unit/aws/deadline/test_delete_worker.py | 31 ++++- test/unit/aws/deadline/test_update_worker.py | 114 ++++++++++++++---- .../deadline/test_update_worker_schedule.py | 33 ++++- test/unit/scheduler/test_scheduler.py | 74 +----------- test/unit/startup/test_bootstrap.py | 7 +- 13 files changed, 387 insertions(+), 162 deletions(-) diff --git a/src/deadline_worker_agent/aws/deadline/__init__.py b/src/deadline_worker_agent/aws/deadline/__init__.py index f9d1e97f..564376af 100644 --- a/src/deadline_worker_agent/aws/deadline/__init__.py +++ b/src/deadline_worker_agent/aws/deadline/__init__.py @@ -5,6 +5,7 @@ from typing import Any, Optional from threading import Event from dataclasses import dataclass +import random from botocore.retries.standard import RetryContext from botocore.exceptions import ClientError @@ -12,7 +13,8 @@ from deadline.client.api import get_telemetry_client, TelemetryClient from ..._version import __version__ as version # noqa -from ...startup.config import Configuration, Capabilities +from ...startup.config import Configuration +from ...startup.capabilities import Capabilities from ...boto import DeadlineClient, NoOverflowExponentialBackoff as Backoff from ...api_models import ( AssumeFleetRoleForWorkerResponse, @@ -120,6 +122,18 @@ def _get_error_reason_from_header(response: dict[str, Any]) -> Optional[str]: return response.get("reason", None) +def _get_retry_after_seconds_from_header(response: dict[str, Any]) -> Optional[int]: + return response.get("retryAfterSeconds", None) + + +def _apply_lower_bound_to_delay(delay: float, lower_bound: Optional[float] = None) -> float: + if lower_bound is not None and delay < lower_bound: + # We add just a tiny bit of jitter (20%) to the lower bound to reduce the probability + # of a group of workers all retry-storming in lock-step. 
+ delay = lower_bound + random.uniform(0, 0.2 * lower_bound) + return delay + + def _get_resource_id_and_status_from_conflictexception_header( response: dict[str, Any] ) -> tuple[Optional[str], Optional[str]]: @@ -155,7 +169,10 @@ def assume_fleet_role_for_worker( # Retry: # ThrottlingException, InternalServerException delay = backoff.delay_amount(RetryContext(retry)) - code = e.response.get("Error", {}).get("Code", None) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) + code = _get_error_code_from_header(e.response) if code == "ThrottlingException": _logger.info( f"Throttled while attempting to refresh Worker AWS Credentials. Retrying in {delay} seconds..." @@ -216,12 +233,11 @@ def assume_queue_role_for_worker( retry = 0 query_start_time = monotonic() - _logger.info("") # Note: Frozen credentials could expire while doing a retry loop; that's # probably going to manifest as AccessDenied, but I'm not 100% certain. while True: if interrupt_event and interrupt_event.is_set(): - raise DeadlineRequestInterrupted("GetQueueIamCredentials interrupted") + raise DeadlineRequestInterrupted("AssumeQueueRoleForWorker interrupted") try: response = deadline_client.assume_queue_role_for_worker( farmId=farm_id, fleetId=fleet_id, workerId=worker_id, queueId=queue_id @@ -235,7 +251,10 @@ def assume_queue_role_for_worker( # Retry: # ThrottlingException, InternalServerException delay = backoff.delay_amount(RetryContext(retry)) - code = e.response.get("Error", {}).get("Code", None) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) + code = _get_error_code_from_header(e.response) if code == "ThrottlingException": _logger.info( f"Throttled while attempting to refresh Worker AWS Credentials. Retrying in {delay} seconds..." @@ -333,7 +352,10 @@ def batch_get_job_entity( # Retry: # ThrottlingException, InternalServerException delay = backoff.delay_amount(RetryContext(retry)) - code = e.response.get("Error", {}).get("Code", None) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) + code = _get_error_code_from_header(e.response) if code == "ThrottlingException": _logger.info(f"Throttled calling BatchGetJobEntity. Retrying in {delay} seconds...") elif code == "InternalServerException": @@ -377,6 +399,9 @@ def create_worker( break except ClientError as e: delay = backoff.delay_amount(RetryContext(retry)) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) code = _get_error_code_from_header(e.response) if code == "ThrottlingException": _logger.info(f"CreateWorker throttled. Retrying in {delay} seconds...") @@ -444,6 +469,9 @@ def delete_worker( break except ClientError as e: delay = backoff.delay_amount(RetryContext(retry)) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) code = _get_error_code_from_header(e.response) if code == "ThrottlingException": _logger.info(f"DeleteWorker throttled. 
Retrying in {delay} seconds...") @@ -487,16 +515,20 @@ def delete_worker( def update_worker( *, deadline_client: DeadlineClient, - config: Configuration, + farm_id: str, + fleet_id: str, worker_id: str, status: WorkerStatus, + capabilities: Optional[Capabilities] = None, host_properties: Optional[HostProperties] = None, + interrupt_event: Optional[Event] = None, ) -> UpdateWorkerResponse: """Calls the UpdateWorker API to update this Worker's status, capabilities, and/or host properties with the service. Raises: DeadlineRequestConditionallyRecoverableError DeadlineRequestUnrecoverableError + DeadlineRequestInterrupted """ # Retry API call when being throttled @@ -506,26 +538,32 @@ def update_worker( _logger.info(f"Invoking UpdateWorker to set {worker_id} to status={status.value}.") request: dict[str, Any] = dict( - farmId=config.farm_id, - fleetId=config.fleet_id, + farmId=farm_id, + fleetId=fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=status.value, ) + if capabilities: + request["capabilities"] = capabilities.for_update_worker() if host_properties: request["hostProperties"] = host_properties + _logger.debug("UpdateWorker request: %s", request) while True: # If true, then we're trying to go to STARTED but have determined that we must first # go to STOPPED must_stop_first = False - _logger.debug("UpdateWorker request: %s", request) + if interrupt_event and interrupt_event.is_set(): + raise DeadlineRequestInterrupted("UpdateWorker interrupted") try: response = deadline_client.update_worker(**request) break except ClientError as e: delay = backoff.delay_amount(RetryContext(retry)) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) code = _get_error_code_from_header(e.response) skip_sleep = False @@ -578,7 +616,10 @@ def update_worker( raise DeadlineRequestUnrecoverableError(e) if not skip_sleep: - sleep(delay) + if interrupt_event: + interrupt_event.wait(delay) + else: + sleep(delay) retry += 1 except Exception as e: _logger.error("Failed to start worker %s", worker_id) @@ -589,9 +630,11 @@ def update_worker( try: update_worker( deadline_client=deadline_client, - config=config, + farm_id=farm_id, + fleet_id=fleet_id, worker_id=worker_id, status=WorkerStatus.STOPPED, + capabilities=capabilities, host_properties=host_properties, ) except Exception: @@ -695,6 +738,9 @@ def update_worker_schedule( break except ClientError as e: delay = backoff.delay_amount(RetryContext(retry)) + delay = _apply_lower_bound_to_delay( + delay, _get_retry_after_seconds_from_header(e.response) + ) code = _get_error_code_from_header(e.response) if code == "ThrottlingException": diff --git a/src/deadline_worker_agent/scheduler/scheduler.py b/src/deadline_worker_agent/scheduler/scheduler.py index 357903f5..c5377599 100644 --- a/src/deadline_worker_agent/scheduler/scheduler.py +++ b/src/deadline_worker_agent/scheduler/scheduler.py @@ -12,8 +12,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from pathlib import Path -from threading import Event, RLock, Lock -from time import sleep, monotonic +from threading import Event, RLock, Lock, Timer from typing import Callable, Tuple, Union, cast, Optional, Any import logging import os @@ -23,9 +22,9 @@ from openjd.sessions import LOG as OPENJD_SESSION_LOG from openjd.sessions import ActionState, ActionStatus from deadline.job_attachments.asset_sync import AssetSync -from botocore.exceptions import ClientError +from ..aws.deadline import 
update_worker from ..aws_credentials import QueueBoto3Session, AwsCredentialsRefresher from ..boto import DeadlineClient, Session as BotoSession from ..errors import ServiceShutdown @@ -40,8 +39,10 @@ AssignedSession, UpdateWorkerScheduleResponse, UpdatedSessionActionInfo, + WorkerStatus, ) from ..aws.deadline import ( + DeadlineRequestConditionallyRecoverableError, DeadlineRequestError, DeadlineRequestInterrupted, DeadlineRequestWorkerOfflineError, @@ -415,45 +416,36 @@ def _transition_to_stopping(self, timeout: timedelta) -> None: initiated a worker-initiated drain operation, and that it must not be given additional new tasks to work on. """ - request = dict[str, Any]( - farmId=self._farm_id, - fleetId=self._fleet_id, - workerId=self._worker_id, - status="STOPPING", - ) - - start_time = monotonic() - curr_time = start_time - next_backoff = timedelta(microseconds=200 * 1000) # We're only being given timeout seconds to successfully make this request. # That is because the drain operation may be expedited, and we need to move # fast to get to transitioning to STOPPED state after this. - while (curr_time - start_time) < timeout.total_seconds(): - try: - self._deadline.update_worker(**request) - logger.info("Successfully set Worker state to STOPPING.") - break - except ClientError as e: - code = e.response.get("Error", {}).get("Code", None) - if code == "ThrottlingException" or code == "InternalServerException": - # backoff - curr_time = monotonic() - elapsed_time = curr_time - start_time - max_backoff = max( - timedelta(seconds=0), - timedelta(seconds=(timeout.total_seconds() - elapsed_time)), - ) - backoff = min(max_backoff, next_backoff) - next_backoff = next_backoff * 2 - if backoff <= timedelta(seconds=0): - logger.info("Failed to set Worker state to STOPPING: timeout") - break - sleep(backoff.total_seconds()) - else: - logger.info("Failed to set Worker state to STOPPING.") - logger.exception(e) - break + timeout_event = Event() + timer = Timer(interval=timeout.total_seconds(), function=timeout_event.set) + + try: + update_worker( + deadline_client=self._deadline, + farm_id=self._farm_id, + fleet_id=self._fleet_id, + worker_id=self._worker_id, + status=WorkerStatus.STOPPING, + interrupt_event=timeout_event, + ) + logger.info("Successfully set Worker state to STOPPING.") + except DeadlineRequestInterrupted: + logger.info( + "Timeout reached trying to update Worker to STOPPING status. Proceeding without changing status..." + ) + except ( + DeadlineRequestUnrecoverableError, + DeadlineRequestConditionallyRecoverableError, + ) as exc: + logger.warning( + f"Exception updating Worker to STOPPING status. Continuing with drain operation regardless. 
Exception: {str(exc)}" + ) + finally: + timer.cancel() def _updated_session_actions( self, diff --git a/src/deadline_worker_agent/startup/bootstrap.py b/src/deadline_worker_agent/startup/bootstrap.py index 48fe78af..b2b6ccdf 100644 --- a/src/deadline_worker_agent/startup/bootstrap.py +++ b/src/deadline_worker_agent/startup/bootstrap.py @@ -327,9 +327,11 @@ def _start_worker( try: response = update_worker( deadline_client=deadline_client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, + capabilities=config.capabilities, host_properties=host_properties, ) except DeadlineRequestUnrecoverableError: @@ -371,7 +373,8 @@ def _enforce_no_instance_profile_or_stop_worker( try: update_worker( deadline_client=deadline_client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STOPPED, ) diff --git a/src/deadline_worker_agent/startup/entrypoint.py b/src/deadline_worker_agent/startup/entrypoint.py index bf536a00..8936bc93 100644 --- a/src/deadline_worker_agent/startup/entrypoint.py +++ b/src/deadline_worker_agent/startup/entrypoint.py @@ -173,7 +173,8 @@ def filter(self, record: logging.LogRecord) -> bool: try: update_worker( deadline_client=deadline_client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STOPPED, ) diff --git a/test/unit/aws/deadline/test_assume_fleet_role_for_worker.py b/test/unit/aws/deadline/test_assume_fleet_role_for_worker.py index 25f633ab..4533539d 100644 --- a/test/unit/aws/deadline/test_assume_fleet_role_for_worker.py +++ b/test/unit/aws/deadline/test_assume_fleet_role_for_worker.py @@ -98,6 +98,42 @@ def test_retries_when_throttled( sleep_mock.assert_called_once() +@pytest.mark.parametrize("exception_code", ["ThrottlingException", "InternalServerException"]) +def test_respects_retryafter_when_throttled( + client: MagicMock, + farm_id: str, + fleet_id: str, + worker_id: str, + mock_assume_fleet_role_for_worker_response: AssumeFleetRoleForWorkerResponse, + exception_code: str, + sleep_mock: MagicMock, +): + # A test that the time delay for a throttled retry respects the value in the 'retryAfterSeconds' + # property of an exception if one is present. 
+ + # GIVEN + min_retry = 30 + exc = ClientError( + {"Error": {"Code": exception_code, "Message": "A message"}, "retryAfterSeconds": min_retry}, + "AssumeFleetRoleForWorker", + ) + client.assume_fleet_role_for_worker.side_effect = [ + exc, + mock_assume_fleet_role_for_worker_response, + ] + + # WHEN + response = assume_fleet_role_for_worker( + deadline_client=client, farm_id=farm_id, fleet_id=fleet_id, worker_id=worker_id + ) + + # THEN + assert response == mock_assume_fleet_role_for_worker_response + assert client.assume_fleet_role_for_worker.call_count == 2 + sleep_mock.assert_called_once() + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) + + @pytest.mark.parametrize( "exception_code", [ diff --git a/test/unit/aws/deadline/test_assume_queue_role_for_worker.py b/test/unit/aws/deadline/test_assume_queue_role_for_worker.py index 0f987df5..66cbcf0c 100644 --- a/test/unit/aws/deadline/test_assume_queue_role_for_worker.py +++ b/test/unit/aws/deadline/test_assume_queue_role_for_worker.py @@ -154,6 +154,47 @@ def test_retries_when_throttled( sleep_mock.assert_called_once() +@pytest.mark.parametrize("exception_code", ["ThrottlingException", "InternalServerException"]) +def test_respects_retryafter_when_throttled( + client: MagicMock, + farm_id: str, + fleet_id: str, + worker_id: str, + queue_id: str, + mock_assume_queue_role_for_worker_response: AssumeQueueRoleForWorkerResponse, + exception_code: str, + sleep_mock: MagicMock, +): + # A test that the time delay for a throttled retry respects the value in the 'retryAfterSeconds' + # property of an exception if one is present. + + # GIVEN + min_retry = 30 + exc = ClientError( + {"Error": {"Code": exception_code, "Message": "A message"}, "retryAfterSeconds": min_retry}, + "AssumeQueueRoleForWorker", + ) + client.assume_queue_role_for_worker.side_effect = [ + exc, + mock_assume_queue_role_for_worker_response, + ] + + # WHEN + response = assume_queue_role_for_worker( + deadline_client=client, + farm_id=farm_id, + fleet_id=fleet_id, + worker_id=worker_id, + queue_id=queue_id, + ) + + # THEN + assert response == mock_assume_queue_role_for_worker_response + assert client.assume_queue_role_for_worker.call_count == 2 + sleep_mock.assert_called_once() + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) + + def test_limited_retries_when_queue_in_conflict( client: MagicMock, farm_id: str, diff --git a/test/unit/aws/deadline/test_batch_get_job_entity.py b/test/unit/aws/deadline/test_batch_get_job_entity.py index 05abc3b9..8d809d84 100644 --- a/test/unit/aws/deadline/test_batch_get_job_entity.py +++ b/test/unit/aws/deadline/test_batch_get_job_entity.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-from typing import Any, Generator +from typing import Any, Generator, Optional from unittest.mock import MagicMock, patch import pytest from botocore.exceptions import ClientError @@ -66,13 +66,14 @@ def test_success( @pytest.mark.parametrize( - "exception", + "exception,min_retry", [ pytest.param( ClientError( {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, "BatchGetJobEntity", ), + None, id="Throttling", ), pytest.param( @@ -80,8 +81,31 @@ def test_success( {"Error": {"Code": "InternalServerException", "Message": "A message"}}, "BatchGetJobEntity", ), + None, id="InternalServer", ), + pytest.param( + ClientError( + { + "Error": {"Code": "ThrottlingException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "BatchGetJobEntity", + ), + 30, + id="Throttling-minretry", + ), + pytest.param( + ClientError( + { + "Error": {"Code": "InternalServerException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "BatchGetJobEntity", + ), + 30, + id="InternalServer-minretry", + ), ], ) def test_retries_when_appropriate( @@ -90,6 +114,7 @@ def test_retries_when_appropriate( fleet_id: str, worker_id: str, exception: ClientError, + min_retry: Optional[float], sleep_mock: MagicMock, ): # A test that the batch_get_job_entity() function will retry calls to the API when: @@ -111,6 +136,8 @@ def test_retries_when_appropriate( # THEN assert client.batch_get_job_entity.call_count == 2 sleep_mock.assert_called_once() + if min_retry is not None: + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) @pytest.mark.parametrize( diff --git a/test/unit/aws/deadline/test_create_worker.py b/test/unit/aws/deadline/test_create_worker.py index 69437ec9..a80a5568 100644 --- a/test/unit/aws/deadline/test_create_worker.py +++ b/test/unit/aws/deadline/test_create_worker.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-from typing import Generator +from typing import Generator, Optional from unittest.mock import MagicMock, patch import pytest from botocore.exceptions import ClientError @@ -78,12 +78,13 @@ def test_success( @pytest.mark.parametrize( - "exception", + "exception,min_retry", [ pytest.param( ClientError( {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, "CreateWorker" ), + None, id="Throttling", ), pytest.param( @@ -91,8 +92,31 @@ def test_success( {"Error": {"Code": "InternalServerException", "Message": "A message"}}, "CreateWorker", ), + None, id="InternalServer", ), + pytest.param( + ClientError( + { + "Error": {"Code": "ThrottlingException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "CreateWorker", + ), + 30, + id="Throttling-minretry", + ), + pytest.param( + ClientError( + { + "Error": {"Code": "InternalServerException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "CreateWorker", + ), + 30, + id="InternalServer-minretry", + ), pytest.param( ClientError( { @@ -105,6 +129,7 @@ def test_success( }, "CreateWorker", ), + None, id="Fleet-CREATE_IN_PROGRESS", ), ], @@ -115,6 +140,7 @@ def test_retries_when_appropriate( mock_create_worker_response: CreateWorkerResponse, host_properties: HostProperties, exception: ClientError, + min_retry: Optional[float], sleep_mock: MagicMock, ): # A test that the create_worker() function will retry calls to the API when: @@ -132,6 +158,8 @@ def test_retries_when_appropriate( assert response == mock_create_worker_response assert client.create_worker.call_count == 2 sleep_mock.assert_called_once() + if min_retry is not None: + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) @pytest.mark.parametrize( diff --git a/test/unit/aws/deadline/test_delete_worker.py b/test/unit/aws/deadline/test_delete_worker.py index 666e0555..5f626234 100644 --- a/test/unit/aws/deadline/test_delete_worker.py +++ b/test/unit/aws/deadline/test_delete_worker.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-from typing import Generator, Any +from typing import Generator, Any, Optional from unittest.mock import MagicMock, patch import pytest from botocore.exceptions import ClientError @@ -65,12 +65,13 @@ def test_success(client: MagicMock, config: Configuration, worker_id: str) -> No @pytest.mark.parametrize( - "exception", + "exception,min_retry", [ pytest.param( ClientError( {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, "DeleteWorker" ), + None, id="Throttling", ), pytest.param( @@ -78,8 +79,31 @@ def test_success(client: MagicMock, config: Configuration, worker_id: str) -> No {"Error": {"Code": "InternalServerException", "Message": "A message"}}, "DeleteWorker", ), + None, id="InternalServer", ), + pytest.param( + ClientError( + { + "Error": {"Code": "ThrottlingException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "DeleteWorker", + ), + 30, + id="Throttling-minretry", + ), + pytest.param( + ClientError( + { + "Error": {"Code": "InternalServerException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "DeleteWorker", + ), + 30, + id="InternalServer-minretry", + ), ], ) def test_retries_when_appropriate( @@ -87,6 +111,7 @@ def test_retries_when_appropriate( config: Configuration, worker_id: str, exception: ClientError, + min_retry: Optional[float], sleep_mock: MagicMock, ): # A test that the delete_worker() function will retry calls to the API when: @@ -102,6 +127,8 @@ def test_retries_when_appropriate( # THEN assert client.delete_worker.call_count == 2 sleep_mock.assert_called_once() + if min_retry is not None: + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) @pytest.mark.parametrize( diff --git a/test/unit/aws/deadline/test_update_worker.py b/test/unit/aws/deadline/test_update_worker.py index 5e14fe90..9d720810 100644 --- a/test/unit/aws/deadline/test_update_worker.py +++ b/test/unit/aws/deadline/test_update_worker.py @@ -7,6 +7,7 @@ from deadline_worker_agent.aws.deadline import ( update_worker, + DeadlineRequestInterrupted, DeadlineRequestUnrecoverableError, DeadlineRequestConditionallyRecoverableError, ) @@ -103,9 +104,11 @@ def test_success( # WHEN response = update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=status, + capabilities=config.capabilities, host_properties=host_properties, ) @@ -130,6 +133,46 @@ def test_success( ) +def test_can_interrupt( + client: MagicMock, + config: Configuration, + worker_id: str, + sleep_mock: MagicMock, +) -> None: + # A test that the update_worker() function will cease retries when the interrupt + # event it set. 
+ + # GIVEN + event = MagicMock() + event.is_set.side_effect = [False, True] + dummy_response = {"log": AWSLOGS_LOG_CONFIGURATION} + throttle_exc = ClientError( + {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, + "UpdateWorker", + ) + client.update_worker.side_effect = [ + throttle_exc, + throttle_exc, + dummy_response, + ] + + # WHEN + with pytest.raises(DeadlineRequestInterrupted): + update_worker( + deadline_client=client, + farm_id=config.farm_id, + fleet_id=config.fleet_id, + worker_id=worker_id, + status=WorkerStatus.STOPPING, + interrupt_event=event, + ) + + # THEN + assert client.update_worker.call_count == 1 + event.wait.assert_called_once() + sleep_mock.assert_not_called() + + @pytest.mark.parametrize("conflict_status", ["STOPPING", "NOT_COMPATIBLE"]) def test_updates_to_stopped_if_required( client: MagicMock, @@ -165,10 +208,10 @@ def test_updates_to_stopped_if_required( # WHEN response = update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN @@ -179,25 +222,19 @@ def test_updates_to_stopped_if_required( farmId=config.farm_id, fleetId=config.fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=WorkerStatus.STARTED.value, - hostProperties=HOST_PROPERTIES, ), call( farmId=config.farm_id, fleetId=config.fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=WorkerStatus.STOPPED.value, - hostProperties=HOST_PROPERTIES, ), call( farmId=config.farm_id, fleetId=config.fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=WorkerStatus.STARTED.value, - hostProperties=HOST_PROPERTIES, ), ) ) @@ -245,10 +282,10 @@ def test_does_not_recurse_if_not_started( with pytest.raises(DeadlineRequestUnrecoverableError) as exc_context: update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=target_status, - host_properties=HOST_PROPERTIES, ) # THEN @@ -285,10 +322,10 @@ def test_reraises_when_updates_to_stopped( with pytest.raises(DeadlineRequestUnrecoverableError) as exc_context: update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN @@ -299,17 +336,13 @@ def test_reraises_when_updates_to_stopped( farmId=config.farm_id, fleetId=config.fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=WorkerStatus.STARTED.value, - hostProperties=HOST_PROPERTIES, ), call( farmId=config.farm_id, fleetId=config.fleet_id, workerId=worker_id, - capabilities=config.capabilities.for_update_worker(), status=WorkerStatus.STOPPED.value, - hostProperties=HOST_PROPERTIES, ), ) ) @@ -317,12 +350,13 @@ def test_reraises_when_updates_to_stopped( @pytest.mark.parametrize( - "exception", + "exception,min_retry", [ pytest.param( ClientError( {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, "UpdateWorker" ), + None, id="Throttling", ), pytest.param( @@ -330,8 +364,31 @@ def test_reraises_when_updates_to_stopped( {"Error": {"Code": "InternalServerException", "Message": "A message"}}, "UpdateWorker", ), + None, id="InternalServer", ), + pytest.param( + ClientError( + { + "Error": {"Code": "ThrottlingException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + 
"UpdateWorker", + ), + 30, + id="Throttling-minretry", + ), + pytest.param( + ClientError( + { + "Error": {"Code": "InternalServerException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "UpdateWorker", + ), + 30, + id="InternalServer-minretry", + ), pytest.param( ClientError( { @@ -340,6 +397,7 @@ def test_reraises_when_updates_to_stopped( }, "UpdateWorker", ), + None, id="Conflict-CONCURRENT_MODIFICATION", ), pytest.param( @@ -354,6 +412,7 @@ def test_reraises_when_updates_to_stopped( }, "UpdateWorker", ), + None, id="Conflict-STATUS_CONFLICT-worker-ASSOCIATED", ), ], @@ -364,6 +423,7 @@ def test_retries_when_appropriate( worker_id: str, sleep_mock: MagicMock, exception: ClientError, + min_retry: Optional[float], ): # A test that the update_worker() function will retry calls to the API when: # 1. Throttled @@ -378,16 +438,18 @@ def test_retries_when_appropriate( # WHEN response = update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN assert response is expected_response assert client.update_worker.call_count == 2 sleep_mock.assert_called_once() + if min_retry is not None: + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) def test_not_found_raises_conditionally_recoverable( @@ -411,10 +473,10 @@ def test_not_found_raises_conditionally_recoverable( # WHEN update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN @@ -510,10 +572,10 @@ def test_raises_unrecoverable_error( # WHEN update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN @@ -538,10 +600,10 @@ def test_raises_unexpected_exception( # WHEN update_worker( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, - host_properties=HOST_PROPERTIES, ) # THEN diff --git a/test/unit/aws/deadline/test_update_worker_schedule.py b/test/unit/aws/deadline/test_update_worker_schedule.py index 4a1dc060..eda26dbb 100644 --- a/test/unit/aws/deadline/test_update_worker_schedule.py +++ b/test/unit/aws/deadline/test_update_worker_schedule.py @@ -78,8 +78,8 @@ def test_can_interrupt( worker_id: str, sleep_mock: MagicMock, ): - # A test that the update_worker_schedule() function will - # retry calls to the API when throttled. + # A test that the update_worker_schedule() function will cease retries when the interrupt + # event it set. 
# GIVEN event = MagicMock() @@ -111,13 +111,14 @@ def test_can_interrupt( @pytest.mark.parametrize( - "exception", + "exception,min_retry", [ pytest.param( ClientError( {"Error": {"Code": "ThrottlingException", "Message": "A message"}}, "UpdateWorkerSchedule", ), + None, id="Throttling", ), pytest.param( @@ -125,8 +126,31 @@ def test_can_interrupt( {"Error": {"Code": "InternalServerException", "Message": "A message"}}, "UpdateWorkerSchedule", ), + None, id="InternalServer", ), + pytest.param( + ClientError( + { + "Error": {"Code": "ThrottlingException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "UpdateWorkerSchedule", + ), + 30, + id="Throttling-minretry", + ), + pytest.param( + ClientError( + { + "Error": {"Code": "InternalServerException", "Message": "A message"}, + "retryAfterSeconds": 30, + }, + "UpdateWorkerSchedule", + ), + 30, + id="InternalServer-minretry", + ), ], ) def test_retries_when_appropriate( @@ -135,6 +159,7 @@ def test_retries_when_appropriate( fleet_id: str, worker_id: str, exception: ClientError, + min_retry: Optional[float], sleep_mock: MagicMock, ) -> None: # A test that the update_worker_schedule() function will retry calls to the API when: @@ -153,6 +178,8 @@ def test_retries_when_appropriate( assert response == SAMPLE_UPDATE_WORKER_SCHEDULE_RESPONSE assert client.update_worker_schedule.call_count == 2 sleep_mock.assert_called_once() + if min_retry is not None: + assert min_retry <= sleep_mock.call_args.args[0] <= (min_retry + 0.2 * min_retry) @pytest.mark.parametrize( diff --git a/test/unit/scheduler/test_scheduler.py b/test/unit/scheduler/test_scheduler.py index cd54898f..9eda2926 100644 --- a/test/unit/scheduler/test_scheduler.py +++ b/test/unit/scheduler/test_scheduler.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Generator from unittest.mock import ANY, MagicMock, Mock, call, patch -import time from openjd.sessions import ActionState, ActionStatus from botocore.exceptions import ClientError @@ -335,45 +334,13 @@ def test_updates_to_stopping(self, scheduler: WorkerScheduler) -> None: """Most basic test. 
Do we invoke the correct API with the STOPPING state?""" # GIVEN - with patch.object(scheduler, "_deadline") as mock_deadline_client: - api_mock = MagicMock() - mock_deadline_client.update_worker = api_mock - - # WHEN - scheduler._transition_to_stopping(timeout=timedelta(seconds=1)) - - # THEN - api_mock.assert_called_once() - assert api_mock.call_args.kwargs["status"] == "STOPPING" - - @pytest.mark.parametrize("code", ["ThrottlingException", "InternalServerException"]) - def test_retries_on_exception(self, scheduler: WorkerScheduler, code: str) -> None: - """Test that we retry when getting a retryable exception.""" - - # GIVEN - with patch.object(scheduler, "_deadline") as mock_deadline_client: - exception = ClientError( - error_response={ - "Error": { - "Code": code, - "Message": "A message", - }, - }, - operation_name="OpName", - ) - api_mock = MagicMock() - api_mock.side_effect = ( - exception, - {}, - ) - mock_deadline_client.update_worker = api_mock - + with patch.object(scheduler_mod, "update_worker") as mock_update_worker: # WHEN scheduler._transition_to_stopping(timeout=timedelta(seconds=1)) # THEN - api_mock.assert_called() - assert api_mock.call_count == 2 + mock_update_worker.assert_called_once() + assert mock_update_worker.call_args.kwargs["status"] == "STOPPING" @pytest.mark.parametrize( "code", @@ -406,41 +373,6 @@ def test_exits_on_exception(self, scheduler: WorkerScheduler, code: str) -> None # THEN api_mock.assert_called_once() - def test_limited_backoffs(self, scheduler: WorkerScheduler) -> None: - """Test that we do an increasing-duration backoff when throttled.""" - - # GIVEN - - def side_effect(*args, **kwargs): - time.sleep(0.05) - raise ClientError( - error_response={ - "Error": { - "Code": "ThrottlingException", - "Message": "A message", - }, - }, - operation_name="OpName", - ) - - with ( - patch.object(scheduler, "_deadline") as mock_deadline_client, - patch.object(scheduler_mod, "sleep") as mock_sleep, - ): - api_mock = MagicMock() - api_mock.side_effect = side_effect - mock_deadline_client.update_worker = api_mock - - # WHEN - scheduler._transition_to_stopping(timeout=timedelta(seconds=1)) - - # THEN - api_mock.assert_called() - assert api_mock.call_count > 1 - assert mock_sleep.call_count > 2 - # back-offs are growing in length: - assert mock_sleep.call_args_list[0].args[0] < mock_sleep.call_args_list[1].args[0] - class TestSchedulerSync: """Tests for WorkerScheduler._sync()""" diff --git a/test/unit/startup/test_bootstrap.py b/test/unit/startup/test_bootstrap.py index c847f62c..72c75d0a 100644 --- a/test/unit/startup/test_bootstrap.py +++ b/test/unit/startup/test_bootstrap.py @@ -752,9 +752,11 @@ def test_success( mock_get_host_properties.assert_called_once() update_worker_mock.assert_called_once_with( deadline_client=deadline_client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STARTED, + capabilities=config.capabilities, host_properties=host_properties, ) if not log_config: @@ -982,7 +984,8 @@ def test_instance_profile_attached_stops_worker( mock_enforce_no_instance_profile.assert_called_once_with() update_worker_mock.assert_called_once_with( deadline_client=client, - config=config, + farm_id=config.farm_id, + fleet_id=config.fleet_id, worker_id=worker_id, status=WorkerStatus.STOPPED, )
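The retry pattern that this patch repeats in every API wrapper can be summarized in isolation. The sketch below is a minimal, self-contained illustration and is not part of the patch: call_with_backoff, api_call, and compute_backoff_delay are hypothetical placeholders (the real wrappers use NoOverflowExponentialBackoff with botocore's RetryContext), while _apply_lower_bound_to_delay and the 20% jitter mirror the helper added in src/deadline_worker_agent/aws/deadline/__init__.py.

    import random
    import time
    from typing import Any, Callable, Optional

    from botocore.exceptions import ClientError


    def _retry_after_seconds(response: dict[str, Any]) -> Optional[int]:
        # The service's load-shedding hint arrives as a top-level
        # "retryAfterSeconds" field in the error response, when present.
        return response.get("retryAfterSeconds", None)


    def _apply_lower_bound_to_delay(delay: float, lower_bound: Optional[float] = None) -> float:
        # Treat retryAfterSeconds as a floor on the computed backoff delay, and
        # add up to 20% jitter so a fleet of workers does not retry in lock-step.
        if lower_bound is not None and delay < lower_bound:
            delay = lower_bound + random.uniform(0, 0.2 * lower_bound)
        return delay


    def call_with_backoff(
        api_call: Callable[[], Any],
        compute_backoff_delay: Callable[[int], float],
    ) -> Any:
        # Simplified retry loop for illustration only: retry on throttling and
        # internal-server errors, respecting the service-suggested minimum delay.
        retry = 0
        while True:
            try:
                return api_call()
            except ClientError as e:
                code = e.response.get("Error", {}).get("Code", None)
                if code not in ("ThrottlingException", "InternalServerException"):
                    raise
                delay = compute_backoff_delay(retry)
                delay = _apply_lower_bound_to_delay(delay, _retry_after_seconds(e.response))
                time.sleep(delay)
                retry += 1

In the agent itself, the interruptible wrappers replace the time.sleep(delay) call with interrupt_event.wait(delay) so that a drain timeout (as in the scheduler's _transition_to_stopping) can cut the retry loop short.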