diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index e8f5f0880cc93c..1b4ffc45afe4a7 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -283,8 +283,20 @@ def execute(self, context: Context) -> dict: raise AirflowException(f"Sagemaker Processing Job creation failed: {response}") if self.deferrable and self.wait_for_completion: + response = self.hook.describe_processing_job(self.config["ProcessingJobName"]) + status = response["ProcessingJobStatus"] + if status in self.hook.failed_states: + raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") + elif status == "Completed": + self.log.info("%s completed successfully.", self.task_id) + return {"Processing": serialize(response)} + + timeout = self.execution_timeout + if self.max_ingestion_time: + timeout = datetime.timedelta(seconds=self.max_ingestion_time) + self.defer( - timeout=self.execution_timeout, + timeout=timeout, trigger=SageMakerTrigger( job_name=self.config["ProcessingJobName"], job_type="Processing", @@ -304,6 +316,7 @@ def execute_complete(self, context, event=None): else: self.log.info(event["message"]) self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"])) + self.log.info("%s completed successfully.", self.task_id) return {"Processing": self.serialized_job} def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage: diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 0135ba13fe300a..3a9c9c21f1aa9d 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -101,6 +101,10 @@ def setup_method(self): check_interval=5, ) + self.defer_processing_config_kwargs = dict( + task_id="test_sagemaker_operator", wait_for_completion=True, check_interval=5, deferrable=True + ) + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) @mock.patch.object( @@ -243,6 +247,9 @@ def test_action_if_job_exists_validation(self, mock_client): action_if_job_exists="not_fail_or_increment", ) + @mock.patch.object( + SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "InProgress"} + ) @mock.patch.object( SageMakerHook, "create_processing_job", @@ -252,17 +259,64 @@ def test_action_if_job_exists_validation(self, mock_client): }, ) @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) - def test_operator_defer(self, mock_job_exists, mock_processing): + def test_operator_defer(self, mock_job_exists, mock_processing, mock_describe): sagemaker_operator = SageMakerProcessingOperator( - **self.processing_config_kwargs, + **self.defer_processing_config_kwargs, config=CREATE_PROCESSING_PARAMS, - deferrable=True, ) sagemaker_operator.wait_for_completion = True with pytest.raises(TaskDeferred) as exc: sagemaker_operator.execute(context=None) assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger" + @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer") + @mock.patch.object( + SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "Completed"} + ) + @mock.patch.object( + SageMakerHook, + "create_processing_job", + return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, + ) + @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) + def test_operator_complete_before_defer( + self, mock_job_exists, mock_processing, mock_describe, mock_defer + ): + sagemaker_operator = SageMakerProcessingOperator( + **self.defer_processing_config_kwargs, + config=CREATE_PROCESSING_PARAMS, + ) + sagemaker_operator.execute(context=None) + assert not mock_defer.called + + @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer") + @mock.patch.object( + SageMakerHook, + "describe_processing_job", + return_value={"ProcessingJobStatus": "Failed", "FailureReason": "It failed"}, + ) + @mock.patch.object( + SageMakerHook, + "create_processing_job", + return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, + ) + @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) + def test_operator_failed_before_defer( + self, + mock_job_exists, + mock_processing, + mock_describe, + mock_defer, + ): + sagemaker_operator = SageMakerProcessingOperator( + **self.defer_processing_config_kwargs, + config=CREATE_PROCESSING_PARAMS, + ) + with pytest.raises(AirflowException): + sagemaker_operator.execute(context=None) + + assert not mock_defer.called + @mock.patch.object( SageMakerHook, "describe_processing_job",