Skip to content

Commit 28b93d0

Browse files
authored
Unify client SDK loggers (#106)
Having two loggers is confusing. Let's have one. This change is backward compatible. Older code will continue to work. API is ```python for event in EventStreamPrinter.gen(inference.chat_completion(...): event.print() ``` ## Test Plan ```bash $ llama-stack-client inference chat-completion --message "Hi how are you" --stream Assistant> I'm just a language model, I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help with any questions or tasks you may have! How about you? How's your day going so far? $ ``` ```bash cd llama-stack/tests/client-sdk LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -s -v agents/test_agents.py -k agent_simple ```
1 parent 0be2852 commit 28b93d0

File tree

3 files changed

+88
-32
lines changed

3 files changed

+88
-32
lines changed

src/llama_stack_client/lib/agents/event_logger.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _process(c) -> str:
3131
return _process(content)
3232

3333

34-
class LogEvent:
34+
class TurnStreamPrintableEvent:
3535
def __init__(
3636
self,
3737
role: Optional[str] = None,
@@ -54,36 +54,55 @@ def print(self, flush=True):
5454
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
5555

5656

57-
class EventLogger:
57+
class TurnStreamEventPrinter:
58+
def __init__(self):
59+
self.previous_event_type = None
60+
self.previous_step_type = None
61+
62+
def process_chunk(self, chunk):
63+
log_event = self._get_log_event(
64+
chunk, self.previous_event_type, self.previous_step_type
65+
)
66+
self.previous_event_type, self.previous_step_type = (
67+
self._get_event_type_step_type(chunk)
68+
)
69+
return log_event
70+
5871
def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
5972
if hasattr(chunk, "error"):
60-
yield LogEvent(role=None, content=chunk.error["message"], color="red")
73+
yield TurnStreamPrintableEvent(
74+
role=None, content=chunk.error["message"], color="red"
75+
)
6176
return
6277

6378
if not hasattr(chunk, "event"):
6479
# Need to check for custom tool first
6580
# since it does not produce event but instead
6681
# a Message
6782
if isinstance(chunk, ToolResponseMessage):
68-
yield LogEvent(role="CustomTool", content=chunk.content, color="green")
83+
yield TurnStreamPrintableEvent(
84+
role="CustomTool", content=chunk.content, color="green"
85+
)
6986
return
7087

7188
event = chunk.event
7289
event_type = event.payload.event_type
7390

7491
if event_type in {"turn_start", "turn_complete"}:
7592
# Currently not logging any turn realted info
76-
yield LogEvent(role=None, content="", end="", color="grey")
93+
yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
7794
return
7895

7996
step_type = event.payload.step_type
8097
# handle safety
8198
if step_type == "shield_call" and event_type == "step_complete":
8299
violation = event.payload.step_details.violation
83100
if not violation:
84-
yield LogEvent(role=step_type, content="No Violation", color="magenta")
101+
yield TurnStreamPrintableEvent(
102+
role=step_type, content="No Violation", color="magenta"
103+
)
85104
else:
86-
yield LogEvent(
105+
yield TurnStreamPrintableEvent(
87106
role=step_type,
88107
content=f"{violation.metadata} {violation.user_message}",
89108
color="red",
@@ -92,33 +111,35 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
92111
# handle inference
93112
if step_type == "inference":
94113
if event_type == "step_start":
95-
yield LogEvent(role=step_type, content="", end="", color="yellow")
114+
yield TurnStreamPrintableEvent(
115+
role=step_type, content="", end="", color="yellow"
116+
)
96117
elif event_type == "step_progress":
97118
if event.payload.delta.type == "tool_call":
98119
if isinstance(event.payload.delta.tool_call, str):
99-
yield LogEvent(
120+
yield TurnStreamPrintableEvent(
100121
role=None,
101122
content=event.payload.delta.tool_call,
102123
end="",
103124
color="cyan",
104125
)
105126
elif event.payload.delta.type == "text":
106-
yield LogEvent(
127+
yield TurnStreamPrintableEvent(
107128
role=None,
108129
content=event.payload.delta.text,
109130
end="",
110131
color="yellow",
111132
)
112133
else:
113134
# step complete
114-
yield LogEvent(role=None, content="")
135+
yield TurnStreamPrintableEvent(role=None, content="")
115136

116137
# handle tool_execution
117138
if step_type == "tool_execution" and event_type == "step_complete":
118139
# Only print tool calls and responses at the step_complete event
119140
details = event.payload.step_details
120141
for t in details.tool_calls:
121-
yield LogEvent(
142+
yield TurnStreamPrintableEvent(
122143
role=step_type,
123144
content=f"Tool:{t.tool_name} Args:{t.arguments}",
124145
color="green",
@@ -129,13 +150,13 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
129150
inserted_context = interleaved_content_as_str(r.content)
130151
content = f"fetched {len(inserted_context)} bytes from memory"
131152

132-
yield LogEvent(
153+
yield TurnStreamPrintableEvent(
133154
role=step_type,
134155
content=content,
135156
color="cyan",
136157
)
137158
else:
138-
yield LogEvent(
159+
yield TurnStreamPrintableEvent(
139160
role=step_type,
140161
content=f"Tool:{r.tool_name} Response:{r.content}",
141162
color="green",
@@ -154,15 +175,11 @@ def _get_event_type_step_type(self, chunk):
154175
return previous_event_type, previous_step_type
155176
return None, None
156177

157-
def log(self, event_generator):
158-
previous_event_type = None
159-
previous_step_type = None
160178

179+
class EventLogger:
180+
def log(self, event_generator):
181+
printer = TurnStreamEventPrinter()
161182
for chunk in event_generator:
162-
for log_event in self._get_log_event(
163-
chunk, previous_event_type, previous_step_type
164-
):
165-
yield log_event
166-
previous_event_type, previous_step_type = self._get_event_type_step_type(
167-
chunk
168-
)
183+
printable_event = printer.process_chunk(chunk)
184+
if printable_event:
185+
yield printable_event

src/llama_stack_client/lib/inference/event_logger.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from termcolor import cprint
77

88

9-
class LogEvent:
9+
class InferenceStreamPrintableEvent:
1010
def __init__(
1111
self,
1212
content: str = "",
@@ -21,13 +21,24 @@ def print(self, flush=True):
2121
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
2222

2323

24+
class InferenceStreamLogEventPrinter:
25+
def process_chunk(self, chunk):
26+
event = chunk.event
27+
if event.event_type == "start":
28+
return InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
29+
elif event.event_type == "progress":
30+
return InferenceStreamPrintableEvent(
31+
event.delta.text, color="yellow", end=""
32+
)
33+
elif event.event_type == "complete":
34+
return InferenceStreamPrintableEvent("")
35+
return None
36+
37+
2438
class EventLogger:
2539
def log(self, event_generator):
40+
printer = InferenceStreamLogEventPrinter()
2641
for chunk in event_generator:
27-
event = chunk.event
28-
if event.event_type == "start":
29-
yield LogEvent("Assistant> ", color="cyan", end="")
30-
elif event.event_type == "progress":
31-
yield LogEvent(event.delta.text, color="yellow", end="")
32-
elif event.event_type == "complete":
33-
yield LogEvent("")
42+
printable_event = printer.process_chunk(chunk)
43+
if printable_event:
44+
yield printable_event
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from .agents.event_logger import TurnStreamEventPrinter
2+
from .inference.event_logger import InferenceStreamLogEventPrinter
3+
4+
5+
class EventStreamPrinter:
6+
@classmethod
7+
def gen(cls, event_generator):
8+
inference_printer = None
9+
turn_printer = None
10+
for chunk in event_generator:
11+
if not hasattr(chunk, "event"):
12+
raise ValueError(f"Unexpected chunk without event: {chunk}")
13+
14+
event = chunk.event
15+
if hasattr(event, "event_type"):
16+
if not inference_printer:
17+
inference_printer = InferenceStreamLogEventPrinter()
18+
printable_event = inference_printer.process_chunk(chunk)
19+
if printable_event:
20+
yield printable_event
21+
elif hasattr(event, "payload") and hasattr(event.payload, "event_type"):
22+
if not turn_printer:
23+
turn_printer = TurnStreamEventPrinter()
24+
printable_event = turn_printer.process_chunk(chunk)
25+
if printable_event:
26+
yield printable_event
27+
else:
28+
raise ValueError(f"Unsupported event: {event}")

0 commit comments

Comments
 (0)