From 8c856024cb0b8c5a8fbed69ca8c03685d6a76b12 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 30 Jan 2025 11:02:00 -0800
Subject: [PATCH 1/4] Unify client SDK loggers

---
 .../lib/agents/event_logger.py               | 63 ++++++++++++-------
 .../lib/inference/event_logger.py            | 25 +++++---
 src/llama_stack_client/lib/stream_logger.py  | 27 ++++++++
 3 files changed, 84 insertions(+), 31 deletions(-)
 create mode 100644 src/llama_stack_client/lib/stream_logger.py

diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index f690d93c..68e02538 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -31,7 +31,7 @@ def _process(c) -> str:
     return _process(content)
 
 
-class LogEvent:
+class TurnStreamLogEvent:
     def __init__(
         self,
         role: Optional[str] = None,
@@ -54,10 +54,25 @@ def print(self, flush=True):
         cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
 
 
-class EventLogger:
+class TurnStreamEventLogger:
+    def __init__(self):
+        self.previous_event_type = None
+        self.previous_step_type = None
+
+    def process_chunk(self, chunk):
+        log_event = self._get_log_event(
+            chunk, self.previous_event_type, self.previous_step_type
+        )
+        self.previous_event_type, self.previous_step_type = (
+            self._get_event_type_step_type(chunk)
+        )
+        return log_event
+
     def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
         if hasattr(chunk, "error"):
-            yield LogEvent(role=None, content=chunk.error["message"], color="red")
+            yield TurnStreamLogEvent(
+                role=None, content=chunk.error["message"], color="red"
+            )
             return
 
         if not hasattr(chunk, "event"):
@@ -65,7 +80,9 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
             # since it does not produce an event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield LogEvent(role="CustomTool", content=chunk.content, color="green")
+                yield TurnStreamLogEvent(
+                    role="CustomTool", content=chunk.content, color="green"
+                )
             return
 
         event = chunk.event
@@ -73,7 +90,7 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         if event_type in {"turn_start", "turn_complete"}:
             # Currently not logging any turn related info
-            yield LogEvent(role=None, content="", end="", color="grey")
+            yield TurnStreamLogEvent(role=None, content="", end="", color="grey")
             return
 
         step_type = event.payload.step_type
@@ -81,9 +98,11 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         if step_type == "shield_call" and event_type == "step_complete":
             violation = event.payload.step_details.violation
             if not violation:
-                yield LogEvent(role=step_type, content="No Violation", color="magenta")
+                yield TurnStreamLogEvent(
+                    role=step_type, content="No Violation", color="magenta"
+                )
             else:
-                yield LogEvent(
+                yield TurnStreamLogEvent(
                     role=step_type,
                     content=f"{violation.metadata} {violation.user_message}",
                     color="red",
@@ -92,18 +111,20 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         # handle inference
         if step_type == "inference":
             if event_type == "step_start":
-                yield LogEvent(role=step_type, content="", end="", color="yellow")
+                yield TurnStreamLogEvent(
+                    role=step_type, content="", end="", color="yellow"
+                )
             elif event_type == "step_progress":
                 if event.payload.delta.type == "tool_call":
                     if isinstance(event.payload.delta.tool_call, str):
-                        yield LogEvent(
+                        yield TurnStreamLogEvent(
                             role=None,
                             content=event.payload.delta.tool_call,
                             end="",
                             color="cyan",
                         )
                 elif event.payload.delta.type == "text":
-                    yield LogEvent(
+                    yield TurnStreamLogEvent(
                         role=None,
                         content=event.payload.delta.text,
                         end="",
@@ -111,14 +132,14 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
                     )
             else:
                 # step complete
-                yield LogEvent(role=None, content="")
+                yield TurnStreamLogEvent(role=None, content="")
 
         # handle tool_execution
         if step_type == "tool_execution" and event_type == "step_complete":
             # Only print tool calls and responses at the step_complete event
             details = event.payload.step_details
             for t in details.tool_calls:
-                yield LogEvent(
+                yield TurnStreamLogEvent(
                     role=step_type,
                     content=f"Tool:{t.tool_name} Args:{t.arguments}",
                     color="green",
@@ -129,13 +150,13 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
                     inserted_context = interleaved_content_as_str(r.content)
                     content = f"fetched {len(inserted_context)} bytes from memory"
 
-                    yield LogEvent(
+                    yield TurnStreamLogEvent(
                         role=step_type,
                         content=content,
                         color="cyan",
                     )
                 else:
-                    yield LogEvent(
+                    yield TurnStreamLogEvent(
                         role=step_type,
                         content=f"Tool:{r.tool_name} Response:{r.content}",
                         color="green",
@@ -154,15 +175,11 @@ def _get_event_type_step_type(self, chunk):
             return previous_event_type, previous_step_type
         return None, None
 
-    def log(self, event_generator):
-        previous_event_type = None
-        previous_step_type = None
 
+class EventLogger:
+    def log(self, event_generator):
+        logger = TurnStreamEventLogger()
         for chunk in event_generator:
-            for log_event in self._get_log_event(
-                chunk, previous_event_type, previous_step_type
-            ):
+            log_event = logger.process_chunk(chunk)
+            if log_event:
                 yield log_event
-            previous_event_type, previous_step_type = self._get_event_type_step_type(
-                chunk
-            )
diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py
index 333f4dcc..e0ec3a24 100644
--- a/src/llama_stack_client/lib/inference/event_logger.py
+++ b/src/llama_stack_client/lib/inference/event_logger.py
@@ -6,7 +6,7 @@
 from termcolor import cprint
 
 
-class LogEvent:
+class InferenceStreamLogEvent:
     def __init__(
         self,
         content: str = "",
@@ -21,13 +21,22 @@ def print(self, flush=True):
         cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
 
 
+class InferenceStreamLogEventLogger:
+    def process_chunk(self, chunk):
+        event = chunk.event
+        if event.event_type == "start":
+            return InferenceStreamLogEvent("Assistant> ", color="cyan", end="")
+        elif event.event_type == "progress":
+            return InferenceStreamLogEvent(event.delta.text, color="yellow", end="")
+        elif event.event_type == "complete":
+            return InferenceStreamLogEvent("")
+        return None
+
+
 class EventLogger:
     def log(self, event_generator):
+        logger = InferenceStreamLogEventLogger()
         for chunk in event_generator:
-            event = chunk.event
-            if event.event_type == "start":
-                yield LogEvent("Assistant> ", color="cyan", end="")
-            elif event.event_type == "progress":
-                yield LogEvent(event.delta.text, color="yellow", end="")
-            elif event.event_type == "complete":
-                yield LogEvent("")
+            log_event = logger.process_chunk(chunk)
+            if log_event:
+                yield log_event
diff --git a/src/llama_stack_client/lib/stream_logger.py b/src/llama_stack_client/lib/stream_logger.py
new file mode 100644
index 00000000..4c0ac35c
--- /dev/null
+++ b/src/llama_stack_client/lib/stream_logger.py
@@ -0,0 +1,27 @@
+from .agents.event_logger import TurnStreamEventLogger
+from .inference.event_logger import InferenceStreamLogEventLogger
+
+
+class EventLogger:
+    def log(self, event_generator):
+        inference_logger = None
+        turn_logger = None
+        for chunk in event_generator:
+            if not hasattr(chunk, "event"):
+                raise ValueError(f"Unexpected chunk without event: {chunk}")
+
+            event = chunk.event
+            if hasattr(event, "event_type"):
+                if not inference_logger:
+                    inference_logger = InferenceStreamLogEventLogger()
+                log_event = inference_logger.process_chunk(chunk)
+                if log_event:
+                    yield log_event
+            elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
+                if not turn_logger:
+                    turn_logger = TurnStreamEventLogger()
+                log_event = turn_logger.process_chunk(chunk)
+                if log_event:
+                    yield log_event
+            else:
+                raise ValueError(f"Unsupported event: {event}")
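
Note on PATCH 1: the agents-side TurnStreamEventLogger.process_chunk returns the
generator object produced by _get_log_event instead of iterating it, so
EventLogger.log ends up yielding generator objects rather than TurnStreamLogEvent
instances, and its `if log_event:` guard filters nothing, because generator
objects are always truthy. PATCH 4 below fixes this by iterating the events
(yield_printable_events). A minimal standalone sketch of the truthiness pitfall
(illustrative, not part of the patch):

    def make_events():
        # Calling a generator function returns a generator object
        # immediately; the body does not run until iteration.
        if False:
            yield "event"

    log_event = make_events()
    print(bool(log_event))  # True, even though iterating yields nothing
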
From 8e06e73fdff9f72e2e58bcc3d9367c6fef54f43b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 30 Jan 2025 11:06:37 -0800
Subject: [PATCH 2/4] make this a classmethod

---
 src/llama_stack_client/lib/stream_logger.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llama_stack_client/lib/stream_logger.py b/src/llama_stack_client/lib/stream_logger.py
index 4c0ac35c..94b6ce9e 100644
--- a/src/llama_stack_client/lib/stream_logger.py
+++ b/src/llama_stack_client/lib/stream_logger.py
@@ -2,8 +2,9 @@
 from .inference.event_logger import InferenceStreamLogEventLogger
 
 
-class EventLogger:
-    def log(self, event_generator):
+class EventStreamLogger:
+    @classmethod
+    def gen(cls, event_generator):
         inference_logger = None
         turn_logger = None
         for chunk in event_generator:
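
Note on PATCH 2: gen keeps no state on the class itself; the per-stream state
lives in the lazily constructed inference/turn loggers inside the generator, so
a classmethod avoids a needless instance. An illustrative call against an
inference (chat-completion style) stream, where `chunks` stands in for any
streaming-response iterator (agent-turn streams still hit the process_chunk
issue noted above until PATCH 4):

    from llama_stack_client.lib.stream_logger import EventStreamLogger

    for log_event in EventStreamLogger.gen(chunks):  # chunks: illustrative name
        log_event.print()
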
From e5915d9f1703c13ed156ed3f6d1408188db432f2 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 30 Jan 2025 14:35:34 -0800
Subject: [PATCH 3/4] logger -> printer

---
 .../lib/agents/event_logger.py               | 36 +++++++++----------
 .../lib/inference/event_logger.py            | 20 ++++++-----
 src/llama_stack_client/lib/stream_logger.py  | 28 ---------------
 src/llama_stack_client/lib/stream_printer.py | 28 +++++++++++++++
 4 files changed, 57 insertions(+), 55 deletions(-)
 delete mode 100644 src/llama_stack_client/lib/stream_logger.py
 create mode 100644 src/llama_stack_client/lib/stream_printer.py

diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index 68e02538..794ff6ac 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -31,7 +31,7 @@ def _process(c) -> str:
     return _process(content)
 
 
-class TurnStreamLogEvent:
+class TurnStreamPrintableEvent:
     def __init__(
         self,
         role: Optional[str] = None,
@@ -54,7 +54,7 @@ def print(self, flush=True):
         cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
 
 
-class TurnStreamEventLogger:
+class TurnStreamEventPrinter:
     def __init__(self):
         self.previous_event_type = None
         self.previous_step_type = None
@@ -70,7 +70,7 @@ def process_chunk(self, chunk):
 
     def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
         if hasattr(chunk, "error"):
-            yield TurnStreamLogEvent(
+            yield TurnStreamPrintableEvent(
                 role=None, content=chunk.error["message"], color="red"
             )
             return
@@ -80,7 +80,7 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
             # since it does not produce an event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield TurnStreamLogEvent(
+                yield TurnStreamPrintableEvent(
                     role="CustomTool", content=chunk.content, color="green"
                 )
             return
@@ -90,7 +90,7 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         if event_type in {"turn_start", "turn_complete"}:
             # Currently not logging any turn related info
-            yield TurnStreamLogEvent(role=None, content="", end="", color="grey")
+            yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
             return
 
         step_type = event.payload.step_type
@@ -98,11 +98,11 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         if step_type == "shield_call" and event_type == "step_complete":
             violation = event.payload.step_details.violation
             if not violation:
-                yield TurnStreamLogEvent(
+                yield TurnStreamPrintableEvent(
                     role=step_type, content="No Violation", color="magenta"
                 )
             else:
-                yield TurnStreamLogEvent(
+                yield TurnStreamPrintableEvent(
                     role=step_type,
                     content=f"{violation.metadata} {violation.user_message}",
                     color="red",
@@ -111,20 +111,20 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
         # handle inference
         if step_type == "inference":
             if event_type == "step_start":
-                yield TurnStreamLogEvent(
+                yield TurnStreamPrintableEvent(
                     role=step_type, content="", end="", color="yellow"
                 )
             elif event_type == "step_progress":
                 if event.payload.delta.type == "tool_call":
                     if isinstance(event.payload.delta.tool_call, str):
-                        yield TurnStreamLogEvent(
+                        yield TurnStreamPrintableEvent(
                             role=None,
                             content=event.payload.delta.tool_call,
                             end="",
                             color="cyan",
                         )
                 elif event.payload.delta.type == "text":
-                    yield TurnStreamLogEvent(
+                    yield TurnStreamPrintableEvent(
                         role=None,
                         content=event.payload.delta.text,
                         end="",
@@ -132,14 +132,14 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
                     )
             else:
                 # step complete
-                yield TurnStreamLogEvent(role=None, content="")
+                yield TurnStreamPrintableEvent(role=None, content="")
 
         # handle tool_execution
         if step_type == "tool_execution" and event_type == "step_complete":
             # Only print tool calls and responses at the step_complete event
             details = event.payload.step_details
             for t in details.tool_calls:
-                yield TurnStreamLogEvent(
+                yield TurnStreamPrintableEvent(
                     role=step_type,
                     content=f"Tool:{t.tool_name} Args:{t.arguments}",
                     color="green",
@@ -150,13 +150,13 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
                     inserted_context = interleaved_content_as_str(r.content)
                     content = f"fetched {len(inserted_context)} bytes from memory"
 
-                    yield TurnStreamLogEvent(
+                    yield TurnStreamPrintableEvent(
                         role=step_type,
                         content=content,
                         color="cyan",
                     )
                 else:
-                    yield TurnStreamLogEvent(
+                    yield TurnStreamPrintableEvent(
                         role=step_type,
                         content=f"Tool:{r.tool_name} Response:{r.content}",
                         color="green",
@@ -178,8 +178,8 @@ def _get_event_type_step_type(self, chunk):
 
 class EventLogger:
     def log(self, event_generator):
-        logger = TurnStreamEventLogger()
+        printer = TurnStreamEventPrinter()
         for chunk in event_generator:
-            log_event = logger.process_chunk(chunk)
-            if log_event:
-                yield log_event
+            printable_event = printer.process_chunk(chunk)
+            if printable_event:
+                yield printable_event
diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py
index e0ec3a24..e3bce0de 100644
--- a/src/llama_stack_client/lib/inference/event_logger.py
+++ b/src/llama_stack_client/lib/inference/event_logger.py
@@ -6,7 +6,7 @@
 from termcolor import cprint
 
 
-class InferenceStreamLogEvent:
+class InferenceStreamPrintableEvent:
     def __init__(
         self,
         content: str = "",
@@ -21,22 +21,24 @@ def print(self, flush=True):
         cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
 
 
-class InferenceStreamLogEventLogger:
+class InferenceStreamLogEventPrinter:
     def process_chunk(self, chunk):
         event = chunk.event
         if event.event_type == "start":
-            return InferenceStreamLogEvent("Assistant> ", color="cyan", end="")
+            return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
         elif event.event_type == "progress":
-            return InferenceStreamLogEvent(event.delta.text, color="yellow", end="")
+            return InferenceStreamPrintableEvent(
+                event.delta.text, color="yellow", end=""
+            )
         elif event.event_type == "complete":
-            return InferenceStreamLogEvent("")
+            return InferenceStreamPrintableEvent("")
         return None
 
 
 class EventLogger:
     def log(self, event_generator):
-        logger = InferenceStreamLogEventLogger()
+        printer = InferenceStreamLogEventPrinter()
         for chunk in event_generator:
-            log_event = logger.process_chunk(chunk)
-            if log_event:
-                yield log_event
+            printable_event = printer.process_chunk(chunk)
+            if printable_event:
+                yield printable_event
diff --git a/src/llama_stack_client/lib/stream_logger.py b/src/llama_stack_client/lib/stream_logger.py
deleted file mode 100644
index 94b6ce9e..00000000
--- a/src/llama_stack_client/lib/stream_logger.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from .agents.event_logger import TurnStreamEventLogger
-from .inference.event_logger import InferenceStreamLogEventLogger
-
-
-class EventStreamLogger:
-    @classmethod
-    def gen(cls, event_generator):
-        inference_logger = None
-        turn_logger = None
-        for chunk in event_generator:
-            if not hasattr(chunk, "event"):
-                raise ValueError(f"Unexpected chunk without event: {chunk}")
-
-            event = chunk.event
-            if hasattr(event, "event_type"):
-                if not inference_logger:
-                    inference_logger = InferenceStreamLogEventLogger()
-                log_event = inference_logger.process_chunk(chunk)
-                if log_event:
-                    yield log_event
-            elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
-                if not turn_logger:
-                    turn_logger = TurnStreamEventLogger()
-                log_event = turn_logger.process_chunk(chunk)
-                if log_event:
-                    yield log_event
-            else:
-                raise ValueError(f"Unsupported event: {event}")
diff --git a/src/llama_stack_client/lib/stream_printer.py b/src/llama_stack_client/lib/stream_printer.py
new file mode 100644
index 00000000..89e7c0a4
--- /dev/null
+++ b/src/llama_stack_client/lib/stream_printer.py
@@ -0,0 +1,28 @@
+from .agents.event_logger import TurnStreamEventPrinter
+from .inference.event_logger import InferenceStreamLogEventPrinter
+
+
+class EventStreamPrinter:
+    @classmethod
+    def gen(cls, event_generator):
+        inference_printer = None
+        turn_printer = None
+        for chunk in event_generator:
+            if not hasattr(chunk, "event"):
+                raise ValueError(f"Unexpected chunk without event: {chunk}")
+
+            event = chunk.event
+            if hasattr(event, "event_type"):
+                if not inference_printer:
+                    inference_printer = InferenceStreamLogEventPrinter()
+                printable_event = inference_printer.process_chunk(chunk)
+                if printable_event:
+                    yield printable_event
+            elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
+                if not turn_printer:
+                    turn_printer = TurnStreamEventPrinter()
+                printable_event = turn_printer.process_chunk(chunk)
+                if printable_event:
+                    yield printable_event
+            else:
+                raise ValueError(f"Unsupported event: {event}")
From 41cac2c5c6bbc9ba9f5bb99736834a56dab321cf Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 30 Jan 2025 14:49:54 -0800
Subject: [PATCH 4/4] naming changes, bug fix after testing the agents test

---
 .../lib/agents/event_logger.py               | 17 +++++++++--------
 .../lib/inference/event_logger.py            | 13 +++++-------
 src/llama_stack_client/lib/stream_printer.py |  8 ++------
 3 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index 794ff6ac..d20eeb38 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -59,16 +59,19 @@ def __init__(self):
         self.previous_event_type = None
         self.previous_step_type = None
 
-    def process_chunk(self, chunk):
-        log_event = self._get_log_event(
+    def yield_printable_events(self, chunk):
+        for printable_event in self._yield_printable_events(
             chunk, self.previous_event_type, self.previous_step_type
-        )
+        ):
+            yield printable_event
+
         self.previous_event_type, self.previous_step_type = (
             self._get_event_type_step_type(chunk)
         )
-        return log_event
 
-    def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
+    def _yield_printable_events(
+        self, chunk, previous_event_type=None, previous_step_type=None
+    ):
         if hasattr(chunk, "error"):
             yield TurnStreamPrintableEvent(
                 role=None, content=chunk.error["message"], color="red"
             )
@@ -180,6 +183,4 @@ class EventLogger:
     def log(self, event_generator):
         printer = TurnStreamEventPrinter()
         for chunk in event_generator:
-            printable_event = printer.process_chunk(chunk)
-            if printable_event:
-                yield printable_event
+            yield from printer.yield_printable_events(chunk)
diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py
index e3bce0de..18e97ad2 100644
--- a/src/llama_stack_client/lib/inference/event_logger.py
+++ b/src/llama_stack_client/lib/inference/event_logger.py
@@ -22,23 +22,20 @@ def print(self, flush=True):
 
 
 class InferenceStreamLogEventPrinter:
-    def process_chunk(self, chunk):
+    def yield_printable_events(self, chunk):
         event = chunk.event
         if event.event_type == "start":
-            return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
+            yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
         elif event.event_type == "progress":
-            return InferenceStreamPrintableEvent(
+            yield InferenceStreamPrintableEvent(
                 event.delta.text, color="yellow", end=""
             )
         elif event.event_type == "complete":
-            return InferenceStreamPrintableEvent("")
-        return None
+            yield InferenceStreamPrintableEvent("")
 
 
 class EventLogger:
     def log(self, event_generator):
         printer = InferenceStreamLogEventPrinter()
         for chunk in event_generator:
-            printable_event = printer.process_chunk(chunk)
-            if printable_event:
-                yield printable_event
+            yield from printer.yield_printable_events(chunk)
diff --git a/src/llama_stack_client/lib/stream_printer.py b/src/llama_stack_client/lib/stream_printer.py
index 89e7c0a4..a08d9663 100644
--- a/src/llama_stack_client/lib/stream_printer.py
+++ b/src/llama_stack_client/lib/stream_printer.py
@@ -15,14 +15,10 @@ def gen(cls, event_generator):
             if hasattr(event, "event_type"):
                 if not inference_printer:
                     inference_printer = InferenceStreamLogEventPrinter()
-                printable_event = inference_printer.process_chunk(chunk)
-                if printable_event:
-                    yield printable_event
+                yield from inference_printer.yield_printable_events(chunk)
             elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
                 if not turn_printer:
                     turn_printer = TurnStreamEventPrinter()
-                printable_event = turn_printer.process_chunk(chunk)
-                if printable_event:
-                    yield printable_event
+                yield from turn_printer.yield_printable_events(chunk)
             else:
                 raise ValueError(f"Unsupported event: {event}")