diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index f690d93c..794ff6ac 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -31,7 +31,7 @@ def _process(c) -> str: return _process(content) -class LogEvent: +class TurnStreamPrintableEvent: def __init__( self, role: Optional[str] = None, @@ -54,10 +54,25 @@ def print(self, flush=True): cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) -class EventLogger: +class TurnStreamEventPrinter: + def __init__(self): + self.previous_event_type = None + self.previous_step_type = None + + def process_chunk(self, chunk): + log_event = self._get_log_event( + chunk, self.previous_event_type, self.previous_step_type + ) + self.previous_event_type, self.previous_step_type = ( + self._get_event_type_step_type(chunk) + ) + return log_event + def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None): if hasattr(chunk, "error"): - yield LogEvent(role=None, content=chunk.error["message"], color="red") + yield TurnStreamPrintableEvent( + role=None, content=chunk.error["message"], color="red" + ) return if not hasattr(chunk, "event"): @@ -65,7 +80,9 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non # since it does not produce event but instead # a Message if isinstance(chunk, ToolResponseMessage): - yield LogEvent(role="CustomTool", content=chunk.content, color="green") + yield TurnStreamPrintableEvent( + role="CustomTool", content=chunk.content, color="green" + ) return event = chunk.event @@ -73,7 +90,7 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non if event_type in {"turn_start", "turn_complete"}: # Currently not logging any turn realted info - yield LogEvent(role=None, content="", end="", color="grey") + yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey") return step_type = event.payload.step_type @@ -81,9 +98,11 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non if step_type == "shield_call" and event_type == "step_complete": violation = event.payload.step_details.violation if not violation: - yield LogEvent(role=step_type, content="No Violation", color="magenta") + yield TurnStreamPrintableEvent( + role=step_type, content="No Violation", color="magenta" + ) else: - yield LogEvent( + yield TurnStreamPrintableEvent( role=step_type, content=f"{violation.metadata} {violation.user_message}", color="red", @@ -92,18 +111,20 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non # handle inference if step_type == "inference": if event_type == "step_start": - yield LogEvent(role=step_type, content="", end="", color="yellow") + yield TurnStreamPrintableEvent( + role=step_type, content="", end="", color="yellow" + ) elif event_type == "step_progress": if event.payload.delta.type == "tool_call": if isinstance(event.payload.delta.tool_call, str): - yield LogEvent( + yield TurnStreamPrintableEvent( role=None, content=event.payload.delta.tool_call, end="", color="cyan", ) elif event.payload.delta.type == "text": - yield LogEvent( + yield TurnStreamPrintableEvent( role=None, content=event.payload.delta.text, end="", @@ -111,14 +132,14 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non ) else: # step complete - yield LogEvent(role=None, content="") + yield TurnStreamPrintableEvent(role=None, content="") # handle tool_execution if step_type == "tool_execution" and event_type == "step_complete": # Only print tool calls and responses at the step_complete event details = event.payload.step_details for t in details.tool_calls: - yield LogEvent( + yield TurnStreamPrintableEvent( role=step_type, content=f"Tool:{t.tool_name} Args:{t.arguments}", color="green", @@ -129,13 +150,13 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non inserted_context = interleaved_content_as_str(r.content) content = f"fetched {len(inserted_context)} bytes from memory" - yield LogEvent( + yield TurnStreamPrintableEvent( role=step_type, content=content, color="cyan", ) else: - yield LogEvent( + yield TurnStreamPrintableEvent( role=step_type, content=f"Tool:{r.tool_name} Response:{r.content}", color="green", @@ -154,15 +175,11 @@ def _get_event_type_step_type(self, chunk): return previous_event_type, previous_step_type return None, None - def log(self, event_generator): - previous_event_type = None - previous_step_type = None +class EventLogger: + def log(self, event_generator): + printer = TurnStreamEventPrinter() for chunk in event_generator: - for log_event in self._get_log_event( - chunk, previous_event_type, previous_step_type - ): - yield log_event - previous_event_type, previous_step_type = self._get_event_type_step_type( - chunk - ) + printable_event = printer.process_chunk(chunk) + if printable_event: + yield printable_event diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py index 333f4dcc..e3bce0de 100644 --- a/src/llama_stack_client/lib/inference/event_logger.py +++ b/src/llama_stack_client/lib/inference/event_logger.py @@ -6,7 +6,7 @@ from termcolor import cprint -class LogEvent: +class InferenceStreamPrintableEvent: def __init__( self, content: str = "", @@ -21,13 +21,24 @@ def print(self, flush=True): cprint(f"{self.content}", color=self.color, end=self.end, flush=flush) +class InferenceStreamLogEventPrinter: + def process_chunk(self, chunk): + event = chunk.event + if event.event_type == "start": + return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="") + elif event.event_type == "progress": + return InferenceStreamPrintableEvent( + event.delta.text, color="yellow", end="" + ) + elif event.event_type == "complete": + return InferenceStreamPrintableEvent("") + return None + + class EventLogger: def log(self, event_generator): + printer = InferenceStreamLogEventPrinter() for chunk in event_generator: - event = chunk.event - if event.event_type == "start": - yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == "progress": - yield LogEvent(event.delta.text, color="yellow", end="") - elif event.event_type == "complete": - yield LogEvent("") + printable_event = printer.process_chunk(chunk) + if printable_event: + yield printable_event diff --git a/src/llama_stack_client/lib/stream_printer.py b/src/llama_stack_client/lib/stream_printer.py new file mode 100644 index 00000000..89e7c0a4 --- /dev/null +++ b/src/llama_stack_client/lib/stream_printer.py @@ -0,0 +1,28 @@ +from .agents.event_logger import TurnStreamEventPrinter +from .inference.event_logger import InferenceStreamLogEventPrinter + + +class EventStreamPrinter: + @classmethod + def gen(cls, event_generator): + inference_printer = None + turn_printer = None + for chunk in event_generator: + if not hasattr(chunk, "event"): + raise ValueError(f"Unexpected chunk without event: {chunk}") + + event = chunk.event + if hasattr(event, "event_type"): + if not inference_printer: + inference_printer = InferenceStreamLogEventPrinter() + printable_event = inference_printer.process_chunk(chunk) + if printable_event: + yield printable_event + elif hasattr(event, "payload") and hasattr(event.payload, "event_type"): + if not turn_printer: + turn_printer = TurnStreamEventPrinter() + printable_event = turn_printer.process_chunk(chunk) + if printable_event: + yield printable_event + else: + raise ValueError(f"Unsupported event: {event}")