feat(providers/amazon): check sagemaker processing job status before deferring
Lee-W committed Jan 10, 2024
1 parent 95a8310 commit 952147a
Showing 2 changed files with 71 additions and 4 deletions.
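For context, the gist of this change: right after the operator creates the processing job, it describes the job once and only defers to the trigger when the job is still running; a job that has already failed raises immediately, and one that has already completed returns its result without deferring. The sketch below is a minimal, self-contained illustration of that pattern, not the provider's actual code; `maybe_defer`, `FAILED_STATES`, and `JobFailedError` are hypothetical stand-ins.

# Minimal sketch of the "check status before deferring" pattern from this commit.
# Everything here is illustrative: `maybe_defer`, `FAILED_STATES`, and `JobFailedError`
# are hypothetical stand-ins, not the Amazon provider's real API.
import datetime

FAILED_STATES = {"Failed"}  # assumption: the hook exposes a similar set of failed states


class JobFailedError(Exception):
    """Job reached a terminal failure state before we could defer."""


def maybe_defer(describe_job, job_name, execution_timeout=None, max_ingestion_time=None):
    """Decide what to do right after job creation, based on its current status."""
    response = describe_job(job_name)
    status = response["ProcessingJobStatus"]

    if status in FAILED_STATES:
        # Fail fast; deferring a job that has already failed just wastes a trigger.
        raise JobFailedError(f"SageMaker job failed because {response.get('FailureReason')}")
    if status == "Completed":
        # Already finished; return the result instead of deferring.
        return {"action": "return", "payload": {"Processing": response}}

    # Still running: defer, preferring max_ingestion_time (seconds) over the
    # task-level execution timeout, mirroring the diff below.
    timeout = execution_timeout
    if max_ingestion_time:
        timeout = datetime.timedelta(seconds=max_ingestion_time)
    return {"action": "defer", "timeout": timeout}


if __name__ == "__main__":
    # Tiny usage example with a canned describe response.
    result = maybe_defer(lambda name: {"ProcessingJobStatus": "Completed"}, "my-job")
    print(result["action"])  # -> "return"
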
15 changes: 14 additions & 1 deletion airflow/providers/amazon/aws/operators/sagemaker.py
@@ -283,8 +283,20 @@ def execute(self, context: Context) -> dict:
raise AirflowException(f"Sagemaker Processing Job creation failed: {response}")

if self.deferrable and self.wait_for_completion:
response = self.hook.describe_processing_job(self.config["ProcessingJobName"])
status = response["ProcessingJobStatus"]
if status in self.hook.failed_states:
raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
elif status == "Completed":
self.log.info("%s completed successfully.", self.task_id)
return {"Processing": serialize(response)}

timeout = self.execution_timeout
if self.max_ingestion_time:
timeout = datetime.timedelta(seconds=self.max_ingestion_time)

self.defer(
timeout=self.execution_timeout,
timeout=timeout,
trigger=SageMakerTrigger(
job_name=self.config["ProcessingJobName"],
job_type="Processing",
@@ -304,6 +316,7 @@ def execute_complete(self, context, event=None):
else:
self.log.info(event["message"])
self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
self.log.info("%s completed successfully.", self.task_id)
return {"Processing": self.serialized_job}

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
60 changes: 57 additions & 3 deletions tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -101,6 +101,10 @@ def setup_method(self):
check_interval=5,
)

self.defer_processing_config_kwargs = dict(
task_id="test_sagemaker_operator", wait_for_completion=True, check_interval=5, deferrable=True
)

@mock.patch.object(SageMakerHook, "describe_processing_job")
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
@@ -243,6 +247,9 @@ def test_action_if_job_exists_validation(self, mock_client):
action_if_job_exists="not_fail_or_increment",
)

@mock.patch.object(
SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "InProgress"}
)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
@@ -252,17 +259,64 @@ def test_action_if_job_exists_validation(self, mock_client):
},
)
@mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
def test_operator_defer(self, mock_job_exists, mock_processing):
def test_operator_defer(self, mock_job_exists, mock_processing, mock_describe):
sagemaker_operator = SageMakerProcessingOperator(
**self.processing_config_kwargs,
**self.defer_processing_config_kwargs,
config=CREATE_PROCESSING_PARAMS,
deferrable=True,
)
sagemaker_operator.wait_for_completion = True
with pytest.raises(TaskDeferred) as exc:
sagemaker_operator.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"

@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
@mock.patch.object(
SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "Completed"}
)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
)
@mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
def test_operator_complete_before_defer(
self, mock_job_exists, mock_processing, mock_describe, mock_defer
):
sagemaker_operator = SageMakerProcessingOperator(
**self.defer_processing_config_kwargs,
config=CREATE_PROCESSING_PARAMS,
)
sagemaker_operator.execute(context=None)
assert not mock_defer.called

@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
@mock.patch.object(
SageMakerHook,
"describe_processing_job",
return_value={"ProcessingJobStatus": "Failed", "FailureReason": "It failed"},
)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
)
@mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
def test_operator_failed_before_defer(
self,
mock_job_exists,
mock_processing,
mock_describe,
mock_defer,
):
sagemaker_operator = SageMakerProcessingOperator(
**self.defer_processing_config_kwargs,
config=CREATE_PROCESSING_PARAMS,
)
with pytest.raises(AirflowException):
sagemaker_operator.execute(context=None)

assert not mock_defer.called

@mock.patch.object(
SageMakerHook,
"describe_processing_job",
