Skip to content

Commit

Permalink
feat: add type hints to event logger util
Browse files Browse the repository at this point in the history
Summary:

Test Plan:
  • Loading branch information
ehhuang committed Feb 12, 2025
1 parent b5dce10 commit 75c6ffb
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional
from typing import Any, Iterator, Optional, Tuple

from termcolor import cprint

from llama_stack_client.types import InterleavedContent, ToolResponseMessage


def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
def _process(c: Any) -> str:
if isinstance(c, str):
return c
elif hasattr(c, "type"):
Expand All @@ -36,36 +36,38 @@ def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
end: Optional[str] = "\n",
color: str = "white",
) -> None:
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end

def __str__(self):
def __str__(self) -> str:
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"

def print(self, flush=True):
def print(self, flush: bool = True) -> None:
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


class TurnStreamEventPrinter:
def __init__(self):
self.previous_event_type = None
self.previous_step_type = None
def __init__(self) -> None:
self.previous_event_type: Optional[str] = None
self.previous_step_type: Optional[str] = None

def yield_printable_events(self, chunk):
def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]:
for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type):
yield printable_event

self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)

def _yield_printable_events(self, chunk, previous_event_type=None, previous_step_type=None):
def _yield_printable_events(
self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None
) -> Iterator[TurnStreamPrintableEvent]:
if hasattr(chunk, "error"):
yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
return
Expand Down Expand Up @@ -151,7 +153,7 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
color="green",
)

def _get_event_type_step_type(self, chunk):
def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
if hasattr(chunk, "event"):
previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
previous_step_type = (
Expand All @@ -162,7 +164,7 @@ def _get_event_type_step_type(self, chunk):


class EventLogger:
def log(self, event_generator):
def log(self, event_generator: Iterator[Any]) -> Iterator[TurnStreamPrintableEvent]:
printer = TurnStreamEventPrinter()
for chunk in event_generator:
yield from printer.yield_printable_events(chunk)

0 comments on commit 75c6ffb

Please sign in to comment.