From 7471a1a9640f687a5f117292e89b08c82eb773df Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 31 Jan 2024 18:57:09 +0800
Subject: [PATCH] feat(providers/amazon): add describe_log_streams_async and
 get_log_events_async

---
 airflow/providers/amazon/aws/hooks/logs.py    | 86 ++++++++++++++++++-
 .../providers/amazon/aws/hooks/sagemaker.py   | 54 +++++-------
 tests/www/views/test_views_rendered.py        |  2 +-
 3 files changed, 108 insertions(+), 34 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index 5e38ad32e3e4a9..1e7473734fdcfe 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ def get_log_events(
                 num_consecutive_empty_response = 0
 
             continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Async function to get the list of log streams for the specified log group.
+
+        You can list all the log streams or filter the results by prefix. You can also control
+        how the results are ordered.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: If the value is LogStreamName, the results are ordered by log stream name.
+            If the value is LastEventTime, the results are ordered by the event time. The default value is LogStreamName.
+        :param count: The maximum number of items returned.
+        """
+        async with self.async_conn as client:
+            try:
+                response: dict[str, Any] = await client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=stream_prefix,
+                    orderBy=order_by,
+                    limit=count,
+                )
+                return response
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so ignore any errors thrown about that
+                if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                    return None
+                raise
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        """A generator for log items in a single stream. This will yield all the items that are available.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The timestamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: Whether to start from the beginning (True) of the log or
+            at the end of the log (False).
+        """
+        next_token = None
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+                events = response["events"]
+                event_count = len(events)
+
+                if event_count > skip:
+                    events = events[skip:]
+                    skip = 0
+                else:
+                    skip -= event_count
+                    events = []
+
+                for event in events:
+                    await asyncio.sleep(1)
+                    yield event
+
+                if next_token != response["nextForwardToken"]:
+                    next_token = response["nextForwardToken"]
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index b6add4caea9260..12307d82dc6f11 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1349,20 +1349,16 @@ async def describe_training_job_with_log_async(
         log_group = "/aws/sagemaker/TrainingJobs"
 
         if len(stream_names) < instance_count:
-            async with AwsLogsHook(
-                aws_conn_id=self.aws_conn_id, region_name=self.region_name
-            ).async_conn as logs_client:
-                streams = await logs_client.describe_log_streams(
-                    logGroupName=log_group,
-                    logStreamNamePrefix=job_name + "/",
-                    orderBy="LogStreamName",
-                    limit=instance_count,
-                )
+            logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+            streams = await logs_hook.describe_log_streams_async(
+                log_group=log_group,
+                stream_prefix=job_name + "/",
+                order_by="LogStreamName",
+                count=instance_count,
+            )
 
-                stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
-                positions.update(
-                    [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions]
-                )
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+            positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
 
         if len(stream_names) > 0:
             async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
@@ -1408,26 +1404,20 @@ async def get_multi_stream(
         positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
         events: list[Any | None] = []
-        async with AwsLogsHook(
-            aws_conn_id=self.aws_conn_id, region_name=self.region_name
-        ).async_conn as logs_client:
-            event_iters = [
-                await logs_client.get_log_events(
-                    logGroupName=log_group,
-                    logStreamName=s,
-                    startTime=positions[s].timestamp,
-                )
-                for s in streams
-            ]
-            for event_stream in event_iters:
-                if not event_stream:
-                    events.append(None)
-                    continue
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+        for event_stream in event_iters:
+            if not event_stream:
+                events.append(None)
+                continue
 
-                try:
-                    events.append(await event_stream.__anext__())
-                except StopAsyncIteration:
-                    events.append(None)
+            try:
+                events.append(await event_stream.__anext__())
+            except StopAsyncIteration:
+                events.append(None)
 
         while any(events):
             i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
diff --git a/tests/www/views/test_views_rendered.py b/tests/www/views/test_views_rendered.py
index 5aa8e6ba2f5fb6..7621c2561e61dd 100644
--- a/tests/www/views/test_views_rendered.py
+++ b/tests/www/views/test_views_rendered.py
@@ -256,7 +256,7 @@
 def test_rendered_template_secret(admin_client, create_dag_run, task_secret):
     if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
         # Handle collection of the test by non-db case
-        Variable = mock.MagicMock()  # type: ignore[misc]  # noqa: F811
+        Variable = mock.MagicMock()  # type: ignore[misc]
     else:
         initial_db_init()
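
Usage sketch (illustrative only, not part of the diff): the two new hook
methods are meant to be driven together, as describe_training_job_with_log_async
does above. A minimal standalone example, assuming an Airflow connection named
"aws_default", region "us-east-1", a placeholder log group and stream prefix,
and an arbitrary 10-event cutoff:

    import asyncio

    from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

    LOG_GROUP = "/aws/sagemaker/TrainingJobs"  # placeholder log group
    STREAM_PREFIX = "example-training-job/"    # placeholder stream prefix

    async def tail_first_stream() -> None:
        hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")

        # Returns None when the log group does not exist yet; the hook
        # swallows ResourceNotFoundException for that case.
        streams = await hook.describe_log_streams_async(
            log_group=LOG_GROUP,
            stream_prefix=STREAM_PREFIX,
            order_by="LogStreamName",
            count=1,
        )
        if not streams or not streams["logStreams"]:
            return

        stream_name = streams["logStreams"][0]["logStreamName"]

        # get_log_events_async is an async generator, so the call itself is
        # not awaited. It polls indefinitely; the consumer decides when to
        # stop iterating (here: after 10 events).
        seen = 0
        async for event in hook.get_log_events_async(LOG_GROUP, stream_name):
            print(event["timestamp"], event["message"])
            seen += 1
            if seen >= 10:
                break

    asyncio.run(tail_first_stream())

Note the absence of "await" before the get_log_events_async call: awaiting an
async generator raises a TypeError, which is why get_multi_stream above also
builds its event iterators without awaiting them.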