Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry configuration in EmrContainerOperator #37426

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def submit_job(
configuration_overrides: dict | None = None,
client_request_token: str | None = None,
tags: dict | None = None,
retry_max_attempts: int | None = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
Expand All @@ -402,6 +403,7 @@ def submit_job(
:param client_request_token: The client idempotency token of the job run request.
Use this if you want to specify a unique ID to prevent two jobs from getting started.
:param tags: The tags assigned to job runs.
:param retry_max_attempts: The maximum number of attempts on the job's driver.
:return: The ID of the job run request.
"""
params = {
Expand All @@ -415,6 +417,10 @@ def submit_job(
}
if client_request_token:
params["clientToken"] = client_request_token
if retry_max_attempts:
params["retryPolicyConfiguration"] = {
"maxAttempts": retry_max_attempts,
}

response = self.conn.start_job_run(**params)

Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ class EmrContainerOperator(BaseOperator):
:param max_tries: Deprecated - use max_polling_attempts instead.
:param max_polling_attempts: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param job_retry_max_attempts: Maximum number of times to retry when the EMR job fails.
Defaults to None, which disables the retry.
:param tags: The tags assigned to job runs.
Defaults to None
:param deferrable: Run operator in the deferrable mode.
Expand Down Expand Up @@ -534,6 +536,7 @@ def __init__(
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
job_retry_max_attempts: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
Expand All @@ -549,6 +552,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
self.max_polling_attempts = max_polling_attempts
self.job_retry_max_attempts = job_retry_max_attempts
self.tags = tags
self.job_id: str | None = None
self.deferrable = deferrable
Expand Down Expand Up @@ -583,6 +587,7 @@ def execute(self, context: Context) -> str | None:
self.configuration_overrides,
self.client_request_token,
self.tags,
self.job_retry_max_attempts,
)
if self.deferrable:
query_status = self.hook.check_query_status(job_id=self.job_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_execute_without_failure(
self.emr_container.execute(None)

mock_submit_job.assert_called_once_with(
"test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}
"test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}, None
)
mock_check_query_status.assert_called_once_with("jobid_123456")
assert self.emr_container.release_label == "6.3.0-latest"
Expand Down
1 change: 1 addition & 0 deletions tests/system/providers/amazon/aws/example_emr_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def delete_virtual_cluster(virtual_cluster_id):
)
# [END howto_operator_emr_container]
job_starter.wait_for_completion = False
job_starter.job_retry_max_attempts = 5

# [START howto_sensor_emr_container]
job_waiter = EmrContainerSensor(
Expand Down