feat(providers/amazon): add describe_log_streams_async and get_log_events_async

Lee-W committed Feb 1, 2024
1 parent 07b8610 commit 7471a1a
Showing 3 changed files with 108 additions and 34 deletions.
86 changes: 85 additions & 1 deletion airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ def get_log_events(
                 num_consecutive_empty_response = 0
 
             continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Async function to get the list of log streams for the specified log group.
+
+        You can list all the log streams or filter the results by prefix. You can also control
+        how the results are ordered.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: If the value is LogStreamName, the results are ordered by log stream name.
+            If the value is LastEventTime, the results are ordered by the event time.
+            The default value is LogStreamName.
+        :param count: The maximum number of items returned.
+        """
+        async with self.async_conn as client:
+            try:
+                response: dict[str, Any] = await client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=stream_prefix,
+                    orderBy=order_by,
+                    limit=count,
+                )
+                return response
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so ignore any errors thrown about that
+                if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                    return None
+                raise error
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[Any, dict[str, Any]]:
+        """A generator for log items in a single stream; this yields all the items that are available.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The timestamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: Whether to start reading from the beginning (True) of the log
+            or from the end of the log (False).
+        """
+        next_token = None
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+                events = response["events"]
+                event_count = len(events)
+
+                # Skip entries that share the timestamp of the last event already consumed.
+                if event_count > skip:
+                    events = events[skip:]
+                    skip = 0
+                else:
+                    skip -= event_count
+                    events = []
+
+                for event in events:
+                    await asyncio.sleep(1)
+                    yield event
+
+                # An unchanged nextForwardToken means no new events yet; keep polling.
+                if next_token != response["nextForwardToken"]:
+                    next_token = response["nextForwardToken"]
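
Taken together, the two new hook methods let a deferrable operator's trigger tail CloudWatch logs without blocking the event loop. The sketch below is a minimal, hypothetical consumer, not part of the commit: the connection id, region, log group, and stream prefix are placeholders, and it assumes valid AWS credentials and at least one existing stream.

import asyncio

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

LOG_GROUP = "/aws/sagemaker/TrainingJobs"  # illustrative log group


async def tail_first_stream() -> None:
    # "aws_default" and the region are placeholder connection settings.
    hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")

    # Returns None while the log group does not exist yet (ResourceNotFoundException).
    streams = await hook.describe_log_streams_async(
        log_group=LOG_GROUP,
        stream_prefix="my-training-job/",  # hypothetical job name prefix
        order_by="LogStreamName",
        count=5,
    )
    if not streams or not streams["logStreams"]:
        return

    stream_name = streams["logStreams"][0]["logStreamName"]
    seen = 0
    # The generator polls indefinitely, so the consumer decides when to stop.
    async for event in hook.get_log_events_async(LOG_GROUP, stream_name):
        print(event["timestamp"], event["message"])
        seen += 1
        if seen >= 10:
            break


asyncio.run(tail_first_stream())
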
54 changes: 22 additions & 32 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1349,20 +1349,16 @@ async def describe_training_job_with_log_async(
         log_group = "/aws/sagemaker/TrainingJobs"
 
         if len(stream_names) < instance_count:
-            async with AwsLogsHook(
-                aws_conn_id=self.aws_conn_id, region_name=self.region_name
-            ).async_conn as logs_client:
-                streams = await logs_client.describe_log_streams(
-                    logGroupName=log_group,
-                    logStreamNamePrefix=job_name + "/",
-                    orderBy="LogStreamName",
-                    limit=instance_count,
-                )
+            logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+            streams = await logs_hook.describe_log_streams_async(
+                log_group=log_group,
+                stream_prefix=job_name + "/",
+                order_by="LogStreamName",
+                count=instance_count,
+            )
 
-                stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
-                positions.update(
-                    [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions]
-                )
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+            positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
 
         if len(stream_names) > 0:
             async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
@@ -1408,26 +1404,20 @@ async def get_multi_stream(
         positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
         events: list[Any | None] = []
 
-        async with AwsLogsHook(
-            aws_conn_id=self.aws_conn_id, region_name=self.region_name
-        ).async_conn as logs_client:
-            event_iters = [
-                await logs_client.get_log_events(
-                    logGroupName=log_group,
-                    logStreamName=s,
-                    startTime=positions[s].timestamp,
-                )
-                for s in streams
-            ]
-            for event_stream in event_iters:
-                if not event_stream:
-                    events.append(None)
-                    continue
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            # Calling the async generator function returns the generator itself; no await here.
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+        for event_stream in event_iters:
+            if not event_stream:
+                events.append(None)
+                continue
 
-                try:
-                    events.append(await event_stream.__anext__())
-                except StopAsyncIteration:
-                    events.append(None)
+            try:
+                events.append(await event_stream.__anext__())
+            except StopAsyncIteration:
+                events.append(None)
 
         while any(events):
             i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
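
The while any(events) loop (its body is truncated in this view) implements a k-way merge across streams: events holds the next unread event of each stream (None once a stream is exhausted), and argmin picks the stream whose head event has the smallest timestamp, with 9999999999 acting as a sentinel that sorts exhausted streams last. Below is a minimal re-implementation of that helper for illustration, assuming the same call shape as in the hunk; the hook module defines its own equivalent.

from typing import Any, Callable


def argmin(arr: list[Any], f: Callable[[Any], int]) -> int | None:
    """Return the index i that minimizes f(arr[i]), or None for an empty list."""
    min_value: int | None = None
    min_idx: int | None = None
    for idx, item in enumerate(arr):
        if min_value is None or f(item) < min_value:
            min_value = f(item)
            min_idx = idx
    return min_idx


# Heads of three streams; the second stream has nothing buffered yet.
events = [
    {"timestamp": 1706774400500, "message": "algo-1 starting"},
    None,
    {"timestamp": 1706774400100, "message": "algo-2 starting"},
]
i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
print(i)  # 2 -- the third stream holds the earliest event, so it is emitted first
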
2 changes: 1 addition & 1 deletion tests/www/views/test_views_rendered.py
@@ -256,7 +256,7 @@ def test_rendered_template_secret(admin_client, create_dag_run, task_secret):

 if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
     # Handle collection of the test by non-db case
-    Variable = mock.MagicMock()  # type: ignore[misc]  # noqa: F811
+    Variable = mock.MagicMock()  # type: ignore[misc]
 else:
     initial_db_init()

