Skip to content

Commit ef642f2

Browse files
committed
llc enhance
1 parent f6d433b commit ef642f2

File tree

4 files changed

+127
-49
lines changed

4 files changed

+127
-49
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ venv.bak/
3535
/.vscode
3636
/output
3737
dist/
38-
.coda/
38+
.coda/
39+
.DS_Store

cozeloop/integration/langchain/trace_callback.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55
import json
6+
import threading
67
import time
78
import traceback
89
from typing import List, Dict, Union, Any, Optional, Callable, Protocol
@@ -24,8 +25,6 @@
2425
from cozeloop.integration.langchain.trace_model.runtime import RuntimeInfo
2526
from cozeloop.integration.langchain.util import calc_token_usage, get_prompt_tag
2627

27-
_trace_callback_client: Optional[Client] = None
28-
2928

3029
class LoopTracer:
3130
@classmethod
@@ -35,19 +34,28 @@ def get_callback_handler(
3534
modify_name_fn: Optional[Callable[[str], str]] = None,
3635
add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
3736
tags: Dict[str, Any] = None,
37+
child_of: Optional[Span] = None,
38+
state_span_ctx_key: str = None,
3839
):
3940
"""
4041
Do not hold it for a long time, get a new callback_handler for each request.
41-
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
42-
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
42+
client: cozeloop client instance. If not provided, use the default client.
43+
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
44+
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
45+
Its priority is higher than the tags parameter.
46+
tags: default tags dict. Its priority is lower than the add_tags_fn parameter.
47+
child_of: parent span of this callback_handler.
48+
state_span_ctx_key: span context field name in state. If provided, you need to set the field in the state, and we will use it to set the span context in the state.
49+
You can get it from the state to create an inner span in an async node.
4350
"""
44-
global _trace_callback_client
45-
if client:
46-
_trace_callback_client = client
47-
else:
48-
_trace_callback_client = get_default_client()
49-
50-
return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
51+
return LoopTraceCallbackHandler(
52+
name_fn=modify_name_fn,
53+
tags_fn=add_tags_fn,
54+
tags=tags,
55+
child_of=child_of,
56+
client=client,
57+
state_span_ctx_key=state_span_ctx_key,
58+
)
5159

5260

5361
class LoopTraceCallbackHandler(BaseCallbackHandler):
@@ -56,13 +64,21 @@ def __init__(
5664
name_fn: Optional[Callable[[str], str]] = None,
5765
tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
5866
tags: Dict[str, Any] = None,
67+
child_of: Optional[Span] = None,
68+
client: Client = None,
69+
state_span_ctx_key: str = None,
5970
):
6071
super().__init__()
61-
self._space_id = _trace_callback_client.workspace_id
72+
self._client = client if client else get_default_client()
73+
self._space_id = self._client.workspace_id
6274
self.run_map: Dict[str, Run] = {}
6375
self.name_fn = name_fn
6476
self.tags_fn = tags_fn
6577
self._tags = tags if tags else {}
78+
self.trace_id: Optional[str] = None
79+
self._id_lock = threading.Lock()
80+
self._child_of = child_of
81+
self._state_span_ctx_key = state_span_ctx_key
6682

6783
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
6884
span_tags = {}
@@ -73,14 +89,16 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs:
7389
span_tags['input'] = ModelTraceInput([BaseMessage(type='', content=prompt) for prompt in prompts],
7490
kwargs.get('invocation_params', {})).to_json()
7591
except Exception as e:
76-
flow_span.set_error(e)
92+
span_tags['internal_error'] = repr(e)
93+
span_tags['internal_error_trace'] = traceback.format_exc()
7794
finally:
7895
span_tags.update(_get_model_span_tags(**kwargs))
7996
self._set_span_tags(flow_span, span_tags)
8097
# Store some pre-aspect information.
8198
self.run_map[str(kwargs['run_id'])].model_meta = ModelMeta(
8299
message=[{'role': '', 'content': prompt} for prompt in prompts],
83100
model_name=span_tags.get('model_name', ''))
101+
return flow_span
84102

85103
def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
86104
span_tags = {}
@@ -90,14 +108,16 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
90108
try:
91109
span_tags['input'] = ModelTraceInput(messages, kwargs.get('invocation_params', {})).to_json()
92110
except Exception as e:
93-
flow_span.set_error(e)
111+
span_tags['internal_error'] = repr(e)
112+
span_tags['internal_error_trace'] = traceback.format_exc()
94113
finally:
95114
span_tags.update(_get_model_span_tags(**kwargs))
96115
self._set_span_tags(flow_span, span_tags)
97116
# Store some pre-aspect information.
98117
self.run_map[str(kwargs['run_id'])].model_meta = (
99118
ModelMeta(message=[{'role': message.type, 'content': message.content} for inner_messages in messages for
100119
message in inner_messages], model_name=span_tags.get('model_name', '')))
120+
return flow_span
101121

102122
async def on_llm_new_token(self, token: str, *, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
103123
**kwargs: Any) -> None:
@@ -119,10 +139,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
119139
if tags:
120140
self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
121141
except Exception as e:
122-
flow_span.set_error(e)
142+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
143+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
123144
# finish flow_span
124145
self._end_flow_span(flow_span)
125146

147+
def on_llm_error(self, error: Exception, **kwargs: Any) -> Any:
148+
self.on_chain_error(error, **kwargs)
149+
126150
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
127151
flow_span = None
128152
try:
@@ -131,13 +155,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
131155
self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
132156
else:
133157
span_type = 'chain'
134-
if kwargs['name'] == 'LangGraph': # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
158+
if kwargs[
159+
'name'] == 'LangGraph': # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
135160
span_type = 'graph'
136161
flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
137162
flow_span.set_tags({'input': _convert_2_json(inputs)})
138163
except Exception as e:
139164
if flow_span is not None:
140-
flow_span.set_error(e)
165+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
166+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
167+
finally:
168+
if flow_span is not None:
169+
# set trace_id
170+
with self._id_lock:
171+
if hasattr(flow_span, 'context'):
172+
self.trace_id = flow_span.context.trace_id
173+
else:
174+
self.trace_id = flow_span.trace_id
175+
# set span_ctx in state
176+
if self._state_span_ctx_key:
177+
inputs[self._state_span_ctx_key] = flow_span
178+
return flow_span
141179

142180
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
143181
flow_span = self.run_map[str(kwargs['run_id'])].span
@@ -151,16 +189,17 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
151189
else:
152190
flow_span.set_tags({'output': _convert_2_json(outputs)})
153191
except Exception as e:
154-
flow_span.set_error(e)
192+
if flow_span:
193+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
194+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
155195
self._end_flow_span(flow_span)
156196

157197
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
158198
flow_span = self._get_flow_span(**kwargs)
159199
if flow_span is None:
160200
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
161201
flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
162-
flow_span.set_error(error)
163-
flow_span.set_tags({'error_trace': traceback.format_exc()})
202+
flow_span.set_tags({'error': repr(error), 'error_trace': traceback.format_exc()})
164203
self._end_flow_span(flow_span)
165204

166205
def on_tool_start(
@@ -170,13 +209,15 @@ def on_tool_start(
170209
span_name = serialized.get('name', 'unknown')
171210
flow_span = self._new_flow_span(span_name, 'tool', **kwargs)
172211
self._set_span_tags(flow_span, span_tags)
212+
return flow_span
173213

174214
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
175215
flow_span = self._get_flow_span(**kwargs)
176216
try:
177217
flow_span.set_tags({'output': _convert_2_json(output)})
178218
except Exception as e:
179-
flow_span.set_error(e)
219+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
220+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
180221
self._end_flow_span(flow_span)
181222

182223
def on_tool_error(
@@ -186,8 +227,8 @@ def on_tool_error(
186227
if flow_span is None:
187228
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
188229
flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
189-
flow_span.set_error(error)
190-
flow_span.set_tags({'error_trace': traceback.format_exc()})
230+
span_tags = {'error': repr(error), 'error_trace': traceback.format_exc()}
231+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
191232
self._end_flow_span(flow_span)
192233

193234
def on_text(self, text: str, **kwargs: Any) -> Any:
@@ -200,7 +241,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
200241
return
201242

202243
def _end_flow_span(self, span: Span):
203-
span.finish()
244+
if span:
245+
span.finish()
204246

205247
def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
206248
return self._get_model_token_tags(response, **kwargs)
@@ -224,20 +266,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
224266
result['input_cached_tokens'] = input_cached_tokens
225267
elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
226268
for i, generation in enumerate(response.generations[0]):
227-
if isinstance(generation, ChatGeneration) and isinstance(generation.message,(AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
269+
if isinstance(generation, ChatGeneration) and isinstance(generation.message, (
270+
AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
228271
is_get_from_langchain = True
229272
result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
230273
result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
231274
result['tokens'] = result['input_tokens'] + result['output_tokens']
232275
if generation.message.usage_metadata.get('output_token_details', {}):
233-
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
276+
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get(
277+
'reasoning', 0)
234278
if reasoning_tokens:
235279
result['reasoning_tokens'] = reasoning_tokens
236280
if generation.message.usage_metadata.get('input_token_details', {}):
237-
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
281+
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get(
282+
'cache_read', 0)
238283
if input_read_cached_tokens:
239284
result['input_cached_tokens'] = input_read_cached_tokens
240-
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
285+
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details',
286+
{}).get('cache_creation',
287+
0)
241288
if input_creation_cached_tokens:
242289
result['input_creation_cached_tokens'] = input_creation_cached_tokens
243290
if is_get_from_langchain:
@@ -259,7 +306,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
259306
span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
260307
return span_tags
261308

262-
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
309+
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str),
310+
**kwargs: Any) -> None:
263311
# get inputs
264312
params: List[Argument] = []
265313
if isinstance(inputs, str):
@@ -311,6 +359,9 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
311359
parent_span: Span = None
312360
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
313361
parent_span = self.run_map[str(kwargs['parent_run_id'])].span
362+
# only root span use child_of
363+
if parent_span is None and self._child_of:
364+
parent_span = self._child_of
314365
# modify name
315366
error_tag = {}
316367
try:
@@ -321,15 +372,15 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
321372
except Exception as e:
322373
error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
323374
# new span
324-
flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
375+
flow_span = self._client.start_span(span_name, span_type, child_of=parent_span)
325376
run_id = str(kwargs['run_id'])
326377
self.run_map[run_id] = Run(run_id, flow_span, span_type)
327378
# set runtime
328379
flow_span.set_runtime(RuntimeInfo())
329380
# set extra tags
330-
flow_span.set_tags(self._tags) # global tags
381+
flow_span.set_tags(self._tags) # global tags
331382
try:
332-
if self.tags_fn: # add tags fn
383+
if self.tags_fn: # add tags fn
333384
tags = self.tags_fn(node_name)
334385
if isinstance(tags, dict):
335386
flow_span.set_tags(tags)
@@ -365,7 +416,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
365416
class Run:
366417
def __init__(self, run_id: str, span: Span, span_type: str) -> None:
367418
self.run_id = run_id # langchain run_id
368-
self.span_id = span.span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
419+
if hasattr(span, 'context'):
420+
self.span_id = span.context.span_id
421+
else:
422+
self.span_id = span.span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
369423
self.span = span
370424
self.span_type = span_type
371425
self.child_runs: List[Run] = Field(default_factory=list)
@@ -519,7 +573,8 @@ def _convert_inputs(inputs: Any) -> Any:
519573
format_inputs['content'] = inputs.content
520574
return format_inputs
521575
if isinstance(inputs, BaseMessage):
522-
message = Message(role=inputs.type, content=inputs.content, tool_calls=inputs.additional_kwargs.get('tool_calls', []))
576+
message = Message(role=inputs.type, content=inputs.content,
577+
tool_calls=inputs.additional_kwargs.get('tool_calls', []))
523578
return message
524579
if isinstance(inputs, ChatPromptValue):
525580
return _convert_inputs(inputs.messages)

0 commit comments

Comments
 (0)