diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index c6dc88e4e86ef7..e2fb960355b645 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -383,6 +383,7 @@ def submit_job( configuration_overrides: dict | None = None, client_request_token: str | None = None, tags: dict | None = None, + retry_max_attempts: int | None = None, ) -> str: """ Submit a job to the EMR Containers API and return the job ID. @@ -402,6 +403,7 @@ def submit_job( :param client_request_token: The client idempotency token of the job run request. Use this if you want to specify a unique ID to prevent two jobs from getting started. :param tags: The tags assigned to job runs. + :param retry_max_attempts: The maximum number of attempts on the job's driver. :return: The ID of the job run request. """ params = { @@ -415,6 +417,10 @@ def submit_job( } if client_request_token: params["clientToken"] = client_request_token + if retry_max_attempts: + params["retryPolicyConfiguration"] = { + "maxAttempts": retry_max_attempts, + } response = self.conn.start_job_run(**params) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 628490b3427ead..68e1c90296add4 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -503,6 +503,8 @@ class EmrContainerOperator(BaseOperator): :param max_tries: Deprecated - use max_polling_attempts instead. :param max_polling_attempts: Maximum number of times to wait for the job run to finish. Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. + :param job_retry_max_attempts: Maximum number of times to retry when the EMR job fails. + Defaults to None, which disable the retry. :param tags: The tags assigned to job runs. Defaults to None :param deferrable: Run operator in the deferrable mode. @@ -534,6 +536,7 @@ def __init__( max_tries: int | None = None, tags: dict | None = None, max_polling_attempts: int | None = None, + job_retry_max_attempts: int | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: @@ -549,6 +552,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.poll_interval = poll_interval self.max_polling_attempts = max_polling_attempts + self.job_retry_max_attempts = job_retry_max_attempts self.tags = tags self.job_id: str | None = None self.deferrable = deferrable @@ -583,6 +587,7 @@ def execute(self, context: Context) -> str | None: self.configuration_overrides, self.client_request_token, self.tags, + self.job_retry_max_attempts, ) if self.deferrable: query_status = self.hook.check_query_status(job_id=self.job_id) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 368f85d395a4bd..8e94e744d943af 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -71,7 +71,7 @@ def test_execute_without_failure( self.emr_container.execute(None) mock_submit_job.assert_called_once_with( - "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {} + "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}, None ) mock_check_query_status.assert_called_once_with("jobid_123456") assert self.emr_container.release_label == "6.3.0-latest" diff --git a/tests/system/providers/amazon/aws/example_emr_eks.py b/tests/system/providers/amazon/aws/example_emr_eks.py index 28dc7ac3c27e15..428c16f4502eed 100644 --- a/tests/system/providers/amazon/aws/example_emr_eks.py +++ b/tests/system/providers/amazon/aws/example_emr_eks.py @@ -282,6 +282,7 @@ def delete_virtual_cluster(virtual_cluster_id): ) # [END howto_operator_emr_container] job_starter.wait_for_completion = False + job_starter.job_retry_max_attempts = 5 # [START howto_sensor_emr_container] job_waiter = EmrContainerSensor(