diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index dff81994..eb9035c4 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Iterator, Optional, Tuple from termcolor import cprint @@ -12,7 +12,7 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str: - def _process(c) -> str: + def _process(c: Any) -> str: if isinstance(c, str): return c elif hasattr(c, "type"): @@ -36,36 +36,38 @@ def __init__( self, role: Optional[str] = None, content: str = "", - end: str = "\n", - color="white", - ): + end: Optional[str] = "\n", + color: str = "white", + ) -> None: self.role = role self.content = content self.color = color self.end = "\n" if end is None else end - def __str__(self): + def __str__(self) -> str: if self.role is not None: return f"{self.role}> {self.content}" else: return f"{self.content}" - def print(self, flush=True): + def print(self, flush: bool = True) -> None: cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) class TurnStreamEventPrinter: - def __init__(self): - self.previous_event_type = None - self.previous_step_type = None + def __init__(self) -> None: + self.previous_event_type: Optional[str] = None + self.previous_step_type: Optional[str] = None - def yield_printable_events(self, chunk): + def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]: for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type): yield printable_event self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk) - def _yield_printable_events(self, chunk, previous_event_type=None, previous_step_type=None): + def _yield_printable_events( + self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None + ) -> Iterator[TurnStreamPrintableEvent]: if hasattr(chunk, "error"): yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red") return @@ -151,7 +153,7 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step color="green", ) - def _get_event_type_step_type(self, chunk): + def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]: if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = ( @@ -162,7 +164,7 @@ def _get_event_type_step_type(self, chunk): class EventLogger: - def log(self, event_generator): + def log(self, event_generator: Iterator[Any]) -> Iterator[TurnStreamPrintableEvent]: printer = TurnStreamEventPrinter() for chunk in event_generator: yield from printer.yield_printable_events(chunk)