Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional
from typing import Any, Iterator, Optional, Tuple

from termcolor import cprint

from llama_stack_client.types import InterleavedContent, ToolResponseMessage


def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
def _process(c: Any) -> str:
if isinstance(c, str):
return c
elif hasattr(c, "type"):
Expand All @@ -36,36 +36,38 @@ def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
end: Optional[str] = "\n",
color: str = "white",
) -> None:
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end

def __str__(self):
def __str__(self) -> str:
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"

def print(self, flush=True):
def print(self, flush: bool = True) -> None:
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


class TurnStreamEventPrinter:
def __init__(self):
self.previous_event_type = None
self.previous_step_type = None
def __init__(self) -> None:
self.previous_event_type: Optional[str] = None
self.previous_step_type: Optional[str] = None

def yield_printable_events(self, chunk):
def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]:
for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type):
yield printable_event

self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)

def _yield_printable_events(self, chunk, previous_event_type=None, previous_step_type=None):
def _yield_printable_events(
self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None
) -> Iterator[TurnStreamPrintableEvent]:
if hasattr(chunk, "error"):
yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
return
Expand Down Expand Up @@ -151,7 +153,7 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
color="green",
)

def _get_event_type_step_type(self, chunk):
def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
if hasattr(chunk, "event"):
previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
previous_step_type = (
Expand All @@ -162,7 +164,7 @@ def _get_event_type_step_type(self, chunk):


class EventLogger:
def log(self, event_generator):
def log(self, event_generator: Iterator[Any]) -> Iterator[TurnStreamPrintableEvent]:
printer = TurnStreamEventPrinter()
for chunk in event_generator:
yield from printer.yield_printable_events(chunk)