Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _process(c) -> str:
return _process(content)


class LogEvent:
class TurnStreamPrintableEvent:
def __init__(
self,
role: Optional[str] = None,
Expand All @@ -54,36 +54,55 @@ def print(self, flush=True):
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


class TurnStreamEventPrinter:
    """Stateful converter from agent-turn stream chunks to printable events.

    Tracks the event/step type of the previously processed chunk so that
    chunk formatting can depend on what came before.
    """

    def __init__(self):
        # No chunks have been processed yet.
        self.previous_event_type = None
        self.previous_step_type = None

    def process_chunk(self, chunk):
        """Return a generator of printable events for `chunk`.

        `_get_log_event` is a generator method, so the returned value is
        lazy. The previous event/step state is updated immediately after
        the call, but the arguments were bound at call time, so consuming
        the generator later still sees the pre-update state.
        """
        log_event = self._get_log_event(
            chunk, self.previous_event_type, self.previous_step_type
        )
        self.previous_event_type, self.previous_step_type = (
            self._get_event_type_step_type(chunk)
        )
        return log_event

def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
if hasattr(chunk, "error"):
yield LogEvent(role=None, content=chunk.error["message"], color="red")
yield TurnStreamPrintableEvent(
role=None, content=chunk.error["message"], color="red"
)
return

if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield LogEvent(role="CustomTool", content=chunk.content, color="green")
yield TurnStreamPrintableEvent(
role="CustomTool", content=chunk.content, color="green"
)
return

event = chunk.event
event_type = event.payload.event_type

if event_type in {"turn_start", "turn_complete"}:
# Currently not logging any turn related info
yield LogEvent(role=None, content="", end="", color="grey")
yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
return

step_type = event.payload.step_type
# handle safety
if step_type == "shield_call" and event_type == "step_complete":
violation = event.payload.step_details.violation
if not violation:
yield LogEvent(role=step_type, content="No Violation", color="magenta")
yield TurnStreamPrintableEvent(
role=step_type, content="No Violation", color="magenta"
)
else:
yield LogEvent(
yield TurnStreamPrintableEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
Expand All @@ -92,33 +111,35 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
# handle inference
if step_type == "inference":
if event_type == "step_start":
yield LogEvent(role=step_type, content="", end="", color="yellow")
yield TurnStreamPrintableEvent(
role=step_type, content="", end="", color="yellow"
)
elif event_type == "step_progress":
if event.payload.delta.type == "tool_call":
if isinstance(event.payload.delta.tool_call, str):
yield LogEvent(
yield TurnStreamPrintableEvent(
role=None,
content=event.payload.delta.tool_call,
end="",
color="cyan",
)
elif event.payload.delta.type == "text":
yield LogEvent(
yield TurnStreamPrintableEvent(
role=None,
content=event.payload.delta.text,
end="",
color="yellow",
)
else:
# step complete
yield LogEvent(role=None, content="")
yield TurnStreamPrintableEvent(role=None, content="")

# handle tool_execution
if step_type == "tool_execution" and event_type == "step_complete":
# Only print tool calls and responses at the step_complete event
details = event.payload.step_details
for t in details.tool_calls:
yield LogEvent(
yield TurnStreamPrintableEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
Expand All @@ -129,13 +150,13 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
inserted_context = interleaved_content_as_str(r.content)
content = f"fetched {len(inserted_context)} bytes from memory"

yield LogEvent(
yield TurnStreamPrintableEvent(
role=step_type,
content=content,
color="cyan",
)
else:
yield LogEvent(
yield TurnStreamPrintableEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
Expand All @@ -154,15 +175,11 @@ def _get_event_type_step_type(self, chunk):
return previous_event_type, previous_step_type
return None, None

def log(self, event_generator):
previous_event_type = None
previous_step_type = None

class EventLogger:
    """Logs an agent-turn event stream as printable events."""

    def log(self, event_generator):
        """Yield a TurnStreamPrintableEvent for each loggable item in the stream.

        `TurnStreamEventPrinter.process_chunk` returns a generator (its
        underlying `_get_log_event` is a generator method), so it must be
        flattened with `yield from`: yielding the value directly would hand
        callers the generator object rather than the printable events, and
        a truthiness check on a generator object is always True.
        """
        printer = TurnStreamEventPrinter()
        for chunk in event_generator:
            yield from printer.process_chunk(chunk)
27 changes: 19 additions & 8 deletions src/llama_stack_client/lib/inference/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from termcolor import cprint


class LogEvent:
class InferenceStreamPrintableEvent:
def __init__(
self,
content: str = "",
Expand All @@ -21,13 +21,24 @@ def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)


class InferenceStreamLogEventPrinter:
    """Maps inference-stream chunks to printable events, one per chunk at most."""

    def process_chunk(self, chunk):
        """Return the printable event for `chunk`, or None if nothing should print."""
        event_type = chunk.event.event_type
        if event_type == "start":
            # Prefix the assistant's reply exactly once, without a newline.
            return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
        if event_type == "progress":
            # Stream each text delta inline (no trailing newline).
            return InferenceStreamPrintableEvent(
                chunk.event.delta.text, color="yellow", end=""
            )
        if event_type == "complete":
            # Terminate the line when the stream finishes.
            return InferenceStreamPrintableEvent("")
        return None


class EventLogger:
    """Logs an inference event stream as printable events."""

    def log(self, event_generator):
        """Yield an InferenceStreamPrintableEvent for each chunk that produces one.

        Unlike the agents printer, `InferenceStreamLogEventPrinter.process_chunk`
        returns a single event or None, so a plain truthiness check suffices.
        """
        printer = InferenceStreamLogEventPrinter()
        for chunk in event_generator:
            printable_event = printer.process_chunk(chunk)
            # process_chunk returns None for event types that produce no output.
            if printable_event:
                yield printable_event
28 changes: 28 additions & 0 deletions src/llama_stack_client/lib/stream_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from .agents.event_logger import TurnStreamEventPrinter
from .inference.event_logger import InferenceStreamLogEventPrinter


class EventStreamPrinter:
    """Routes a heterogeneous event stream to the appropriate printer.

    Distinguishes inference-style chunks (``chunk.event.event_type``) from
    agent-turn chunks (``chunk.event.payload.event_type``) and lazily creates
    the matching printer on first use.
    """

    @classmethod
    def gen(cls, event_generator):
        """Yield printable events for every chunk in `event_generator`.

        Raises:
            ValueError: if a chunk has no `event` attribute or the event
                matches neither known shape.
        """
        inference_printer = None
        turn_printer = None
        for chunk in event_generator:
            if not hasattr(chunk, "event"):
                # NOTE(review): agent streams can also carry ToolResponseMessage
                # chunks that have no `event` attribute (see the agents
                # printer) — confirm whether those should be routed here
                # instead of rejected.
                raise ValueError(f"Unexpected chunk without event: {chunk}")

            event = chunk.event
            if hasattr(event, "event_type"):
                # Inference-style chunk: at most one printable event per chunk.
                if inference_printer is None:
                    inference_printer = InferenceStreamLogEventPrinter()
                printable_event = inference_printer.process_chunk(chunk)
                if printable_event:
                    yield printable_event
            elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
                # Agent-turn chunk: process_chunk returns a generator, so it
                # must be flattened; yielding it directly would hand callers
                # the generator object (and it is always truthy).
                if turn_printer is None:
                    turn_printer = TurnStreamEventPrinter()
                yield from turn_printer.process_chunk(chunk)
            else:
                raise ValueError(f"Unsupported event: {event}")