diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index 794ff6ac..d20eeb38 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -59,16 +59,19 @@ def __init__(self):
         self.previous_event_type = None
         self.previous_step_type = None
 
-    def process_chunk(self, chunk):
-        log_event = self._get_log_event(
+    def yield_printable_events(self, chunk):
+        for printable_event in self._yield_printable_events(
             chunk, self.previous_event_type, self.previous_step_type
-        )
+        ):
+            yield printable_event
+
         self.previous_event_type, self.previous_step_type = (
             self._get_event_type_step_type(chunk)
         )
-        return log_event
 
-    def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
+    def _yield_printable_events(
+        self, chunk, previous_event_type=None, previous_step_type=None
+    ):
         if hasattr(chunk, "error"):
             yield TurnStreamPrintableEvent(
                 role=None, content=chunk.error["message"], color="red"
@@ -180,6 +183,4 @@ class EventLogger:
     def log(self, event_generator):
         printer = TurnStreamEventPrinter()
         for chunk in event_generator:
-            printable_event = printer.process_chunk(chunk)
-            if printable_event:
-                yield printable_event
+            yield from printer.yield_printable_events(chunk)
diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py
index e3bce0de..18e97ad2 100644
--- a/src/llama_stack_client/lib/inference/event_logger.py
+++ b/src/llama_stack_client/lib/inference/event_logger.py
@@ -22,23 +22,20 @@ def print(self, flush=True):
 
 
 class InferenceStreamLogEventPrinter:
-    def process_chunk(self, chunk):
+    def yield_printable_events(self, chunk):
         event = chunk.event
         if event.event_type == "start":
-            return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
+            yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
         elif event.event_type == "progress":
-            return InferenceStreamPrintableEvent(
+            yield InferenceStreamPrintableEvent(
                 event.delta.text, color="yellow", end=""
             )
         elif event.event_type == "complete":
-            return InferenceStreamPrintableEvent("")
-        return None
+            yield InferenceStreamPrintableEvent("")
 
 
 class EventLogger:
     def log(self, event_generator):
         printer = InferenceStreamLogEventPrinter()
         for chunk in event_generator:
-            printable_event = printer.process_chunk(chunk)
-            if printable_event:
-                yield printable_event
+            yield from printer.yield_printable_events(chunk)
diff --git a/src/llama_stack_client/lib/stream_printer.py b/src/llama_stack_client/lib/stream_printer.py
index 89e7c0a4..a08d9663 100644
--- a/src/llama_stack_client/lib/stream_printer.py
+++ b/src/llama_stack_client/lib/stream_printer.py
@@ -15,14 +15,10 @@ def gen(cls, event_generator):
             if hasattr(event, "event_type"):
                 if not inference_printer:
                     inference_printer = InferenceStreamLogEventPrinter()
-                printable_event = inference_printer.process_chunk(chunk)
-                if printable_event:
-                    yield printable_event
+                yield from inference_printer.yield_printable_events(chunk)
             elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
                 if not turn_printer:
                     turn_printer = TurnStreamEventPrinter()
-                printable_event = turn_printer.process_chunk(chunk)
-                if printable_event:
-                    yield printable_event
+                yield from turn_printer.yield_printable_events(chunk)
             else:
                 raise ValueError(f"Unsupported event: {event}")
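
Usage sketch (not part of the diff above): after this change each printer's method is a generator that can emit zero, one, or several printable events per chunk, so callers drain EventLogger.log with a plain for loop instead of checking for a single optional return value. Only EventLogger, log, and the printable events' print() method come from the code being modified; the `stream` variable below is an assumed iterable of streaming inference chunks, shown purely for illustration.

    from llama_stack_client.lib.inference.event_logger import EventLogger

    # `stream` is a placeholder for a streaming response iterable (an assumption,
    # not defined in this diff); each chunk carries an `event` with an event_type.
    for printable_event in EventLogger().log(stream):
        # InferenceStreamPrintableEvent.print(flush=True) is defined in the file
        # touched above; it writes the colored text fragment to stdout.
        printable_event.print()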