
Commit
check sagemaker training job status before deferring `SageMakerTrainingOperator` (apache#36685)
Lee-W authored and abhishekbhakat committed Mar 5, 2024
1 parent 1ef9b68 commit 144bb22
Showing 6 changed files with 459 additions and 30 deletions.
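For context, a minimal usage sketch (not part of the commit) of the operator this change affects. The DAG id, connection id, and training config below are placeholders, and the full CreateTrainingJob config is omitted. With this change, a deferrable operator checks the training job status up front and only defers when the job is still running; print_log=True routes the deferred path through SageMakerTrainingPrintLogTrigger.

# Hypothetical DAG using SageMakerTrainingOperator in deferrable mode.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

training_config = {
    "TrainingJobName": "my-training-job",
    # AlgorithmSpecification, RoleArn, Input/OutputDataConfig, ResourceConfig,
    # and StoppingCondition omitted; see the CreateTrainingJob API schema.
}

with DAG("example_sagemaker_training", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    train = SageMakerTrainingOperator(
        task_id="train_model",
        config=training_config,
        aws_conn_id="aws_default",
        wait_for_completion=True,
        deferrable=True,
        print_log=True,
        check_interval=30,
    )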
86 changes: 85 additions & 1 deletion airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
# under the License.
from __future__ import annotations

import asyncio
import warnings
from typing import Generator
from typing import Any, AsyncGenerator, Generator

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ def get_log_events(
num_consecutive_empty_response = 0

continuation_token.value = response["nextForwardToken"]

async def describe_log_streams_async(
self, log_group: str, stream_prefix: str, order_by: str, count: int
) -> dict[str, Any] | None:
"""Async function to get the list of log streams for the specified log group.
You can list all the log streams or filter the results by prefix. You can also control
how the results are ordered.
:param log_group: The name of the log group.
:param stream_prefix: The prefix to match.
:param order_by: If the value is LogStreamName, the results are ordered by log stream name.
If the value is LastEventTime, the results are ordered by the event time. The default value is LogStreamName.
:param count: The maximum number of items returned
"""
async with self.async_conn as client:
try:
response: dict[str, Any] = await client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=stream_prefix,
orderBy=order_by,
limit=count,
)
return response
except ClientError as error:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
if error.response["Error"]["Code"] == "ResourceNotFoundException":
return None
raise error

async def get_log_events_async(
self,
log_group: str,
log_stream_name: str,
start_time: int = 0,
skip: int = 0,
start_from_head: bool = True,
) -> AsyncGenerator[Any, dict[str, Any]]:
"""A generator for log items in a single stream. This will yield all the items that are available.
:param log_group: The name of the log group.
:param log_stream_name: The name of the specific stream.
:param start_time: The time stamp value to start reading the logs from (default: 0).
:param skip: The number of log entries to skip at the start (default: 0).
This is for when there are multiple entries at the same timestamp.
:param start_from_head: whether to start from the beginning (True) of the log or
at the end of the log (False).
"""
next_token = None
while True:
if next_token is not None:
token_arg: dict[str, str] = {"nextToken": next_token}
else:
token_arg = {}

async with self.async_conn as client:
response = await client.get_log_events(
logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
startFromHead=start_from_head,
**token_arg,
)

events = response["events"]
event_count = len(events)

if event_count > skip:
events = events[skip:]
skip = 0
else:
skip -= event_count
events = []

for event in events:
await asyncio.sleep(1)
yield event

if next_token != response["nextForwardToken"]:
next_token = response["nextForwardToken"]
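
As a usage illustration of the two new async helpers above (not part of the commit; the log group, stream prefix, region, and connection id are placeholders):

import asyncio

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook


async def tail_first_stream() -> None:
    hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
    # List the log streams for one training job; returns None if the log
    # group does not exist yet (e.g. before the container starts logging).
    streams = await hook.describe_log_streams_async(
        log_group="/aws/sagemaker/TrainingJobs",
        stream_prefix="my-training-job/",
        order_by="LogStreamName",
        count=1,
    )
    if not streams or not streams["logStreams"]:
        return
    stream_name = streams["logStreams"][0]["logStreamName"]
    # get_log_events_async polls indefinitely, so cap the number of events here.
    seen = 0
    async for event in hook.get_log_events_async(
        log_group="/aws/sagemaker/TrainingJobs", log_stream_name=stream_name
    ):
        print(event["timestamp"], event["message"])
        seen += 1
        if seen >= 10:
            break


asyncio.run(tail_first_stream())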
145 changes: 136 additions & 9 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -26,8 +26,9 @@
from collections import Counter, namedtuple
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast
from typing import Any, AsyncGenerator, Callable, Generator, cast

from asgiref.sync import sync_to_async
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@ def create_training_job(
max_ingestion_time,
)

billable_time = (
describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"]
) * describe_response["ResourceConfig"]["InstanceCount"]
self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
billable_seconds = SageMakerHook.count_billable_seconds(
training_start_time=describe_response["TrainingStartTime"],
training_end_time=describe_response["TrainingEndTime"],
instance_count=describe_response["ResourceConfig"]["InstanceCount"],
)
self.log.info("Billable seconds: %d", billable_seconds)

return response

@@ -811,10 +814,12 @@ def check_training_status_with_log(
if status in failed_states:
reason = last_description.get("FailureReason", "(No reason provided)")
raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
billable_time = (
last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
) * instance_count
self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
billable_seconds = SageMakerHook.count_billable_seconds(
training_start_time=last_description["TrainingStartTime"],
training_end_time=last_description["TrainingEndTime"],
instance_count=instance_count,
)
self.log.info("Billable seconds: %d", billable_seconds)

def list_training_jobs(
self, name_contains: str | None = None, max_results: int | None = None, **kwargs
@@ -1300,3 +1305,125 @@ def create_auto_ml_job(
if "BestCandidate" in res:
return res["BestCandidate"]
return None

@staticmethod
def count_billable_seconds(
training_start_time: datetime, training_end_time: datetime, instance_count: int
) -> int:
billable_time = (training_end_time - training_start_time) * instance_count
return int(billable_time.total_seconds()) + 1

async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the training job info associated with the name.
:param job_name: the name of the training job
"""
async with self.async_conn as client:
response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
return response

async def describe_training_job_with_log_async(
self,
job_name: str,
positions: dict[str, Any],
stream_names: list[str],
instance_count: int,
state: int,
last_description: dict[str, Any],
last_describe_job_call: float,
) -> tuple[int, dict[str, Any], float]:
"""
Return the training job info associated with job_name and print CloudWatch logs.
:param job_name: name of the job to check status
:param positions: A list of pairs of (timestamp, skip) which represents the last record
read from each stream.
:param stream_names: A list of the log stream names. The position of the stream in this list is
the stream number.
:param instance_count: Number of instances initially created for the job
:param state: log state
:param last_description: Latest description of the training job
:param last_describe_job_call: Time of the previous describe_training_job call
"""
log_group = "/aws/sagemaker/TrainingJobs"

if len(stream_names) < instance_count:
logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
streams = await logs_hook.describe_log_streams_async(
log_group=log_group,
stream_prefix=job_name + "/",
order_by="LogStreamName",
count=instance_count,
)

stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])

if len(stream_names) > 0:
async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
self.log.info(event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)

if state == LogState.COMPLETE:
return state, last_description, last_describe_job_call

if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = await self.describe_training_job_async(job_name)
last_describe_job_call = time.time()

if await sync_to_async(secondary_training_status_changed)(description, last_description):
self.log.info(
await sync_to_async(secondary_training_status_message)(description, last_description)
)
last_description = description

status = description["TrainingJobStatus"]

if status not in self.non_terminal_states:
state = LogState.JOB_COMPLETE
return state, last_description, last_describe_job_call

async def get_multi_stream(
self, log_group: str, streams: list[str], positions: dict[str, Any]
) -> AsyncGenerator[Any, tuple[int, Any | None]]:
"""Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
:param log_group: The name of the log group.
:param streams: A list of the log stream names. The position of the stream in this list is
the stream number.
:param positions: A list of pairs of (timestamp, skip) which represents the last record
read from each stream.
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
events: list[Any | None] = []

logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
event_iters = [
logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
for s in streams
]
for event_stream in event_iters:
if not event_stream:
events.append(None)
continue

try:
events.append(await event_stream.__anext__())
except StopAsyncIteration:
events.append(None)

while any(events):
i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
yield i, events[i]

try:
events[i] = await event_iters[i].__anext__()
except StopAsyncIteration:
events[i] = None
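
A quick worked example for the new count_billable_seconds helper (values are illustrative): two instances running for 90.5 seconds each come out as int(181.0) + 1 = 182 billable seconds, matching the rounding-up behaviour of the inline calculation it replaces.

from datetime import datetime, timedelta

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

start = datetime(2024, 3, 5, 12, 0, 0)
end = start + timedelta(seconds=90.5)
# (end - start) * 2 instances = 181.0 seconds; int(181.0) + 1 = 182
print(SageMakerHook.count_billable_seconds(start, end, instance_count=2))  # 182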
89 changes: 74 additions & 15 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -29,9 +29,14 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState,
SageMakerHook,
secondary_training_status_message,
)
from airflow.providers.amazon.aws.triggers.sagemaker import (
SageMakerPipelineTrigger,
SageMakerTrainingPrintLogTrigger,
SageMakerTrigger,
)
from airflow.providers.amazon.aws.utils import trim_none_values
@@ -899,9 +904,11 @@ def execute(self, context: Context) -> dict:
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
timeout=datetime.timedelta(seconds=self.max_ingestion_time)
if self.max_ingestion_time is not None
else None,
timeout=(
datetime.timedelta(seconds=self.max_ingestion_time)
if self.max_ingestion_time is not None
else None
),
)
description = {} # never executed but makes static checkers happy
elif self.wait_for_completion:
@@ -1085,28 +1092,80 @@ def execute(self, context: Context) -> dict:
raise AirflowException(f"Sagemaker Training Job creation failed: {response}")

if self.deferrable and self.wait_for_completion:
self.defer(
timeout=self.execution_timeout,
trigger=SageMakerTrigger(
description = self.hook.describe_training_job(self.config["TrainingJobName"])
status = description["TrainingJobStatus"]

if self.print_log:
instance_count = description["ResourceConfig"]["InstanceCount"]
last_describe_job_call = time.monotonic()
job_already_completed = status not in self.hook.non_terminal_states
_, description, last_describe_job_call = self.hook.describe_training_job_with_log(
self.config["TrainingJobName"],
{},
[],
instance_count,
LogState.COMPLETE if job_already_completed else LogState.TAILING,
description,
last_describe_job_call,
)
self.log.info(secondary_training_status_message(description, None))

if status in self.hook.failed_states:
reason = description.get("FailureReason", "(No reason provided)")
raise AirflowException(f"SageMaker job failed because {reason}")
elif status == "Completed":
log_message = f"{self.task_id} completed successfully."
if self.print_log:
billable_seconds = SageMakerHook.count_billable_seconds(
training_start_time=description["TrainingStartTime"],
training_end_time=description["TrainingEndTime"],
instance_count=instance_count,
)
log_message = f"Billable seconds: {billable_seconds}\n{log_message}"
self.log.info(log_message)
return {"Training": serialize(description)}

timeout = self.execution_timeout
if self.max_ingestion_time:
timeout = datetime.timedelta(seconds=self.max_ingestion_time)

trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
if self.print_log:
trigger = SageMakerTrainingPrintLogTrigger(
job_name=self.config["TrainingJobName"],
poke_interval=self.check_interval,
aws_conn_id=self.aws_conn_id,
)
else:
trigger = SageMakerTrigger(
job_name=self.config["TrainingJobName"],
job_type="Training",
poke_interval=self.check_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
)

self.defer(
timeout=timeout,
trigger=trigger,
method_name="execute_complete",
)

self.serialized_training_data = serialize(
self.hook.describe_training_job(self.config["TrainingJobName"])
)
return {"Training": self.serialized_training_data}
return self.serialize_result()

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
if event is None:
err_msg = "Trigger error: event is None"
self.log.error(err_msg)
raise AirflowException(err_msg)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])

self.log.info(event["message"])
return self.serialize_result()

def serialize_result(self) -> dict[str, dict]:
self.serialized_training_data = serialize(
self.hook.describe_training_job(self.config["TrainingJobName"])
)
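
For reference, the reworked execute_complete above expects a trigger event carrying "status" and "message" keys; a hedged sketch of such a payload follows (the exact fields emitted by SageMakerTrainingPrintLogTrigger are an assumption beyond what this diff shows).

from airflow.triggers.base import TriggerEvent

# A successful run surfaces something like this to execute_complete, which
# logs event["message"] and returns the serialized training description.
# Any other status, or event=None, makes execute_complete raise AirflowException.
success_event = TriggerEvent({"status": "success", "message": "SageMaker training job completed."})
print(success_event.payload["status"], "-", success_event.payload["message"])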
