diff --git a/src/deadline_worker_agent/worker.py b/src/deadline_worker_agent/worker.py index bf352a7a..fdfa5b63 100644 --- a/src/deadline_worker_agent/worker.py +++ b/src/deadline_worker_agent/worker.py @@ -239,9 +239,7 @@ def run(self) -> None: logger.debug("monitor ec2 shutdown future complete") worker_shutdown: WorkerShutdown | None = future.result() # We only stop the other threads if we detected an imminent EC2 shutdown. - # The monitoring thread returns None if: - # 1. The Worker is not on EC2, or IMDS is not turned on - # 2. The monitor thread was stopped by the OS signal handler + # The monitoring thread returns None if the monitor thread was stopped by the OS signal handler if worker_shutdown: self._stop.set() self._scheduler.shutdown( @@ -322,7 +320,7 @@ def _monitor_ec2_shutdown(self) -> WorkerShutdown | None: "IMDS unavailable - unable to monitor for spot interruption or ASG life-cycle " "changes" ) - return None + continue # Check for spot interruption or shutdown if ( diff --git a/test/unit/test_worker.py b/test/unit/test_worker.py index 9f38e12e..c8b81f20 100644 --- a/test/unit/test_worker.py +++ b/test/unit/test_worker.py @@ -266,29 +266,41 @@ def test_loops_until_stopped( assert return_value is None - def test_no_imds( + def test_no_imds_temporarily_continues_till_stop_called( self, worker: Worker, mock_logger: MagicMock, ) -> None: """Asserts that when Worker._get_ec2_metadata_imdsv2_token() returns None which indicates - that IMDS is not available, that Worker._monitor_ec2_shutdown() returns None""" + that IMDS is not available, that Worker._monitor_ec2_shutdown() continues looping until + _stop called""" # GIVEN logger_info: MagicMock = mock_logger.info + wait_side_effect = ([False] * 2) + [True] - with patch.object( - worker, "_get_ec2_metadata_imdsv2_token" - ) as mock_get_ec2_metadata_imdsv2_token: + with ( + patch.object( + worker, "_get_ec2_metadata_imdsv2_token" + ) as mock_get_ec2_metadata_imdsv2_token, + patch.object(worker._stop, "wait", side_effect=wait_side_effect), + ): mock_get_ec2_metadata_imdsv2_token.return_value = None # WHEN result = worker._monitor_ec2_shutdown() # THEN - assert result is None - logger_info.assert_called_once_with( - "IMDS unavailable - unable to monitor for spot interruption or ASG life-cycle changes" + logger_info.assert_has_calls( + [ + call( + "IMDS unavailable - unable to monitor for spot interruption or ASG life-cycle changes" + ), + call( + "IMDS unavailable - unable to monitor for spot interruption or ASG life-cycle changes" + ), + ] ) + assert result is None def test_asg_termination( self,