Skip to content

Commit

Permalink
validate aws service exceptions in waiters (#41941)
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan committed Sep 4, 2024
1 parent 5fbfe4c commit bfbff66
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 7 deletions.
32 changes: 26 additions & 6 deletions airflow/providers/amazon/aws/utils/waiter_with_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,21 @@ def wait(
try:
waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
except WaiterError as error:
if "terminal failure" in str(error):
log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, error.last_response))
error_reason = str(error)
last_response = error.last_response

if "terminal failure" in error_reason:
log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
raise AirflowException(f"{failure_message}: {error}")

if (
"An error occurred" in error_reason
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
raise AirflowException(f"{failure_message}: {error}")

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
else:
break
else:
Expand Down Expand Up @@ -122,11 +132,21 @@ async def async_wait(
try:
await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
except WaiterError as error:
if "terminal failure" in str(error):
log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, error.last_response))
error_reason = str(error)
last_response = error.last_response

if "terminal failure" in error_reason:
log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
raise AirflowException(f"{failure_message}: {error}")

if (
"An error occurred" in error_reason
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
raise AirflowException(f"{failure_message}: {error}")

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
else:
break
else:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/waiters/stepfunctions.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"state": "success"
},
{
"matcher": "error",
"matcher": "path",
"argument": "status",
"expected": "RUNNING",
"state": "retry"
Expand Down
71 changes: 71 additions & 0 deletions tests/providers/amazon/aws/utils/test_waiter_with_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,41 @@ async def test_async_wait(self, caplog):
assert mock_waiter.wait.call_count == 3
assert caplog.messages == ["test status message: Pending", "test status message: Pending"]

@pytest.mark.asyncio
async def test_async_wait_with_unknown_failure(self):
mock_waiter = mock.MagicMock()
service_exception = WaiterError(
name="test_waiter",
reason="An error occurred",
last_response={
"Error": {
"Message": "Not authorized to perform: states:DescribeExecution on resource",
"Code": "AccessDeniedException",
}
},
)
mock_waiter.wait = AsyncMock()
mock_waiter.wait.side_effect = [service_exception]
with pytest.raises(AirflowException) as exc:
await async_wait(
waiter=mock_waiter,
waiter_delay=0,
waiter_max_attempts=456,
args={"test_arg": "test_value"},
failure_message="test failure message",
status_message="test status message",
status_args=["Status.State"],
)

mock_waiter.wait.assert_called_with(
**{"test_arg": "test_value"},
WaiterConfig={
"MaxAttempts": 1,
},
)
assert "An error occurred" in str(exc)
assert mock_waiter.wait.call_count == 1

@mock.patch("time.sleep")
def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
mock_sleep.return_value = True
Expand Down Expand Up @@ -187,6 +222,42 @@ def test_wait_with_failure(self, mock_sleep, caplog):
assert mock_waiter.wait.call_count == 4
assert caplog.messages == ["test status message: Pending"] * 3 + ["test failure message: Failure"]

@mock.patch("time.sleep")
def test_wait_with_unknown_failure(self, mock_sleep):
mock_sleep.return_value = True
mock_waiter = mock.MagicMock()
service_exception = WaiterError(
name="test_waiter",
reason="An error occurred",
last_response={
"Error": {
"Message": "Not authorized to perform: states:DescribeExecution on resource",
"Code": "AccessDeniedException",
}
},
)
mock_waiter.wait.side_effect = [service_exception]

with pytest.raises(AirflowException) as exc:
wait(
waiter=mock_waiter,
waiter_delay=123,
waiter_max_attempts=10,
args={"test_arg": "test_value"},
failure_message="test failure message",
status_message="test status message",
status_args=["Status.State"],
)

assert "An error occurred" in str(exc)
mock_waiter.wait.assert_called_with(
**{"test_arg": "test_value"},
WaiterConfig={
"MaxAttempts": 1,
},
)
assert mock_waiter.wait.call_count == 1

@mock.patch("time.sleep")
def test_wait_with_list_response(self, mock_sleep, caplog):
mock_sleep.return_value = True
Expand Down

0 comments on commit bfbff66

Please sign in to comment.