33
44from __future__ import annotations
55import json
6+ import threading
67import time
78import traceback
89from typing import List , Dict , Union , Any , Optional , Callable , Protocol
2425from cozeloop .integration .langchain .trace_model .runtime import RuntimeInfo
2526from cozeloop .integration .langchain .util import calc_token_usage , get_prompt_tag
2627
27- _trace_callback_client : Optional [Client ] = None
28-
2928
3029class LoopTracer :
3130 @classmethod
@@ -35,19 +34,28 @@ def get_callback_handler(
3534 modify_name_fn : Optional [Callable [[str ], str ]] = None ,
3635 add_tags_fn : Optional [Callable [[str ], Dict [str , Any ]]] = None ,
3736 tags : Dict [str , Any ] = None ,
37+ child_of : Optional [Span ] = None ,
38+ state_span_ctx_key : str = None ,
3839 ):
3940 """
4041 Do not hold it for a long time, get a new callback_handler for each request.
41- modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
42- add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
42+ client: cozeloop client instance. If not provided, use the default client.
43+ modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
44+ add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
45+ It's priority higher than parameter tags.
46+ tags: default tags dict. It's priority lower than parameter add_tags_fn.
47+ child_of: parent span of this callback_handler.
48+ state_span_ctx_key: span context field name in state. If provided, you need set the field in sate, and we will use it to set span context in state.
49+ You can get it from state for creating inner span in async node.
4350 """
44- global _trace_callback_client
45- if client :
46- _trace_callback_client = client
47- else :
48- _trace_callback_client = get_default_client ()
49-
50- return LoopTraceCallbackHandler (modify_name_fn , add_tags_fn , tags )
51+ return LoopTraceCallbackHandler (
52+ name_fn = modify_name_fn ,
53+ tags_fn = add_tags_fn ,
54+ tags = tags ,
55+ child_of = child_of ,
56+ client = client ,
57+ state_span_ctx_key = state_span_ctx_key ,
58+ )
5159
5260
5361class LoopTraceCallbackHandler (BaseCallbackHandler ):
@@ -56,13 +64,21 @@ def __init__(
5664 name_fn : Optional [Callable [[str ], str ]] = None ,
5765 tags_fn : Optional [Callable [[str ], Dict [str , Any ]]] = None ,
5866 tags : Dict [str , Any ] = None ,
67+ child_of : Optional [Span ] = None ,
68+ client : Client = None ,
69+ state_span_ctx_key : str = None ,
5970 ):
6071 super ().__init__ ()
61- self ._space_id = _trace_callback_client .workspace_id
72+ self ._client = client if client else get_default_client ()
73+ self ._space_id = self ._client .workspace_id
6274 self .run_map : Dict [str , Run ] = {}
6375 self .name_fn = name_fn
6476 self .tags_fn = tags_fn
6577 self ._tags = tags if tags else {}
78+ self .trace_id : Optional [str ] = None
79+ self ._id_lock = threading .Lock ()
80+ self ._child_of = child_of
81+ self ._state_span_ctx_key = state_span_ctx_key
6682
6783 def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
6884 span_tags = {}
@@ -73,14 +89,16 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs:
7389 span_tags ['input' ] = ModelTraceInput ([BaseMessage (type = '' , content = prompt ) for prompt in prompts ],
7490 kwargs .get ('invocation_params' , {})).to_json ()
7591 except Exception as e :
76- flow_span .set_error (e )
92+ span_tags ['internal_error' ] = repr (e )
93+ span_tags ['internal_error_trace' ] = traceback .format_exc ()
7794 finally :
7895 span_tags .update (_get_model_span_tags (** kwargs ))
7996 self ._set_span_tags (flow_span , span_tags )
8097 # Store some pre-aspect information.
8198 self .run_map [str (kwargs ['run_id' ])].model_meta = ModelMeta (
8299 message = [{'role' : '' , 'content' : prompt } for prompt in prompts ],
83100 model_name = span_tags .get ('model_name' , '' ))
101+ return flow_span
84102
85103 def on_chat_model_start (self , serialized : Dict [str , Any ], messages : List [List [BaseMessage ]], ** kwargs : Any ) -> Any :
86104 span_tags = {}
@@ -90,14 +108,16 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
90108 try :
91109 span_tags ['input' ] = ModelTraceInput (messages , kwargs .get ('invocation_params' , {})).to_json ()
92110 except Exception as e :
93- flow_span .set_error (e )
111+ span_tags ['internal_error' ] = repr (e )
112+ span_tags ['internal_error_trace' ] = traceback .format_exc ()
94113 finally :
95114 span_tags .update (_get_model_span_tags (** kwargs ))
96115 self ._set_span_tags (flow_span , span_tags )
97116 # Store some pre-aspect information.
98117 self .run_map [str (kwargs ['run_id' ])].model_meta = (
99118 ModelMeta (message = [{'role' : message .type , 'content' : message .content } for inner_messages in messages for
100119 message in inner_messages ], model_name = span_tags .get ('model_name' , '' )))
120+ return flow_span
101121
102122 async def on_llm_new_token (self , token : str , * , chunk : Optional [Union [GenerationChunk , ChatGenerationChunk ]] = None ,
103123 ** kwargs : Any ) -> None :
@@ -119,10 +139,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
119139 if tags :
120140 self ._set_span_tags (flow_span , tags , need_convert_tag_value = False )
121141 except Exception as e :
122- flow_span .set_error (e )
142+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
143+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
123144 # finish flow_span
124145 self ._end_flow_span (flow_span )
125146
    def on_llm_error(self, error: Exception, **kwargs: Any) -> Any:
        # LLM failures are traced exactly like chain failures: delegate so the
        # run's span gets 'error'/'error_trace' tags and is finished.
        self.on_chain_error(error, **kwargs)
149+
126150 def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
127151 flow_span = None
128152 try :
@@ -131,13 +155,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
131155 self ._on_prompt_start (flow_span , serialized , inputs , ** kwargs )
132156 else :
133157 span_type = 'chain'
134- if kwargs ['name' ] == 'LangGraph' : # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
158+ if kwargs [
159+ 'name' ] == 'LangGraph' : # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
135160 span_type = 'graph'
136161 flow_span = self ._new_flow_span (kwargs ['name' ], span_type , ** kwargs )
137162 flow_span .set_tags ({'input' : _convert_2_json (inputs )})
138163 except Exception as e :
139164 if flow_span is not None :
140- flow_span .set_error (e )
165+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
166+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
167+ finally :
168+ if flow_span is not None :
169+ # set trace_id
170+ with self ._id_lock :
171+ if hasattr (flow_span , 'context' ):
172+ self .trace_id = flow_span .context .trace_id
173+ else :
174+ self .trace_id = flow_span .trace_id
175+ # set span_ctx in state
176+ if self ._state_span_ctx_key :
177+ inputs [self ._state_span_ctx_key ] = flow_span
178+ return flow_span
141179
142180 def on_chain_end (self , outputs : Union [Dict [str , Any ], Any ], ** kwargs : Any ) -> Any :
143181 flow_span = self .run_map [str (kwargs ['run_id' ])].span
@@ -151,16 +189,17 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
151189 else :
152190 flow_span .set_tags ({'output' : _convert_2_json (outputs )})
153191 except Exception as e :
154- flow_span .set_error (e )
192+ if flow_span :
193+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
194+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
155195 self ._end_flow_span (flow_span )
156196
157197 def on_chain_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
158198 flow_span = self ._get_flow_span (** kwargs )
159199 if flow_span is None :
160200 span_name = '_Exception' if isinstance (error , Exception ) else '_KeyboardInterrupt'
161201 flow_span = self ._new_flow_span (span_name , 'chain_error' , ** kwargs )
162- flow_span .set_error (error )
163- flow_span .set_tags ({'error_trace' : traceback .format_exc ()})
202+ flow_span .set_tags ({'error' : repr (error ), 'error_trace' : traceback .format_exc ()})
164203 self ._end_flow_span (flow_span )
165204
166205 def on_tool_start (
@@ -170,13 +209,15 @@ def on_tool_start(
170209 span_name = serialized .get ('name' , 'unknown' )
171210 flow_span = self ._new_flow_span (span_name , 'tool' , ** kwargs )
172211 self ._set_span_tags (flow_span , span_tags )
212+ return flow_span
173213
174214 def on_tool_end (self , output : str , ** kwargs : Any ) -> Any :
175215 flow_span = self ._get_flow_span (** kwargs )
176216 try :
177217 flow_span .set_tags ({'output' : _convert_2_json (output )})
178218 except Exception as e :
179- flow_span .set_error (e )
219+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
220+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
180221 self ._end_flow_span (flow_span )
181222
182223 def on_tool_error (
@@ -186,8 +227,8 @@ def on_tool_error(
186227 if flow_span is None :
187228 span_name = '_Exception' if isinstance (error , Exception ) else '_KeyboardInterrupt'
188229 flow_span = self ._new_flow_span (span_name , 'tool_error' , ** kwargs )
189- flow_span . set_error (error )
190- flow_span . set_tags ({ 'error_trace' : traceback . format_exc ()} )
230+ span_tags = { 'error' : repr (error ), 'error_trace' : traceback . format_exc ()}
231+ self . _set_span_tags ( flow_span , span_tags , need_convert_tag_value = False )
191232 self ._end_flow_span (flow_span )
192233
193234 def on_text (self , text : str , ** kwargs : Any ) -> Any :
@@ -200,7 +241,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
200241 return
201242
202243 def _end_flow_span (self , span : Span ):
203- span .finish ()
244+ if span :
245+ span .finish ()
204246
    def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
        # Indirection point for model-level span tags; currently all model
        # tags are derived from token usage in the LLM response.
        return self._get_model_token_tags(response, **kwargs)
@@ -224,20 +266,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
224266 result ['input_cached_tokens' ] = input_cached_tokens
225267 elif response .generations is not None and len (response .generations ) > 0 and response .generations [0 ] is not None :
226268 for i , generation in enumerate (response .generations [0 ]):
227- if isinstance (generation , ChatGeneration ) and isinstance (generation .message ,(AIMessageChunk , AIMessage )) and generation .message .usage_metadata :
269+ if isinstance (generation , ChatGeneration ) and isinstance (generation .message , (
270+ AIMessageChunk , AIMessage )) and generation .message .usage_metadata :
228271 is_get_from_langchain = True
229272 result ['input_tokens' ] = generation .message .usage_metadata .get ('input_tokens' , 0 )
230273 result ['output_tokens' ] = generation .message .usage_metadata .get ('output_tokens' , 0 )
231274 result ['tokens' ] = result ['input_tokens' ] + result ['output_tokens' ]
232275 if generation .message .usage_metadata .get ('output_token_details' , {}):
233- reasoning_tokens = generation .message .usage_metadata .get ('output_token_details' , {}).get ('reasoning' , 0 )
276+ reasoning_tokens = generation .message .usage_metadata .get ('output_token_details' , {}).get (
277+ 'reasoning' , 0 )
234278 if reasoning_tokens :
235279 result ['reasoning_tokens' ] = reasoning_tokens
236280 if generation .message .usage_metadata .get ('input_token_details' , {}):
237- input_read_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_read' , 0 )
281+ input_read_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get (
282+ 'cache_read' , 0 )
238283 if input_read_cached_tokens :
239284 result ['input_cached_tokens' ] = input_read_cached_tokens
240- input_creation_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_creation' , 0 )
285+ input_creation_cached_tokens = generation .message .usage_metadata .get ('input_token_details' ,
286+ {}).get ('cache_creation' ,
287+ 0 )
241288 if input_creation_cached_tokens :
242289 result ['input_creation_cached_tokens' ] = input_creation_cached_tokens
243290 if is_get_from_langchain :
@@ -259,7 +306,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
259306 span_tags = {'error_info' : repr (e ), 'error_trace' : traceback .format_exc ()}
260307 return span_tags
261308
262- def _on_prompt_start (self , flow_span , serialized : Dict [str , Any ], inputs : (Dict [str , Any ], str ), ** kwargs : Any ) -> None :
309+ def _on_prompt_start (self , flow_span , serialized : Dict [str , Any ], inputs : (Dict [str , Any ], str ),
310+ ** kwargs : Any ) -> None :
263311 # get inputs
264312 params : List [Argument ] = []
265313 if isinstance (inputs , str ):
@@ -311,6 +359,9 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
311359 parent_span : Span = None
312360 if 'parent_run_id' in kwargs and kwargs ['parent_run_id' ] is not None and str (kwargs ['parent_run_id' ]) in self .run_map :
313361 parent_span = self .run_map [str (kwargs ['parent_run_id' ])].span
362+ # only root span use child_of
363+ if parent_span is None and self ._child_of :
364+ parent_span = self ._child_of
314365 # modify name
315366 error_tag = {}
316367 try :
@@ -321,15 +372,15 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
321372 except Exception as e :
322373 error_tag = {'error_info' : f'name_fn error { repr (e )} ' , 'error_trace' : traceback .format_exc ()}
323374 # new span
324- flow_span = _trace_callback_client .start_span (span_name , span_type , child_of = parent_span )
375+ flow_span = self . _client .start_span (span_name , span_type , child_of = parent_span )
325376 run_id = str (kwargs ['run_id' ])
326377 self .run_map [run_id ] = Run (run_id , flow_span , span_type )
327378 # set runtime
328379 flow_span .set_runtime (RuntimeInfo ())
329380 # set extra tags
330- flow_span .set_tags (self ._tags ) # global tags
381+ flow_span .set_tags (self ._tags ) # global tags
331382 try :
332- if self .tags_fn : # add tags fn
383+ if self .tags_fn : # add tags fn
333384 tags = self .tags_fn (node_name )
334385 if isinstance (tags , dict ):
335386 flow_span .set_tags (tags )
@@ -365,7 +416,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
365416class Run :
366417 def __init__ (self , run_id : str , span : Span , span_type : str ) -> None :
367418 self .run_id = run_id # langchain run_id
368- self .span_id = span .span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
419+ if hasattr (span , 'context' ):
420+ self .span_id = span .context .span_id
421+ else :
422+ self .span_id = span .span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
369423 self .span = span
370424 self .span_type = span_type
371425 self .child_runs : List [Run ] = Field (default_factory = list )
@@ -519,7 +573,8 @@ def _convert_inputs(inputs: Any) -> Any:
519573 format_inputs ['content' ] = inputs .content
520574 return format_inputs
521575 if isinstance (inputs , BaseMessage ):
522- message = Message (role = inputs .type , content = inputs .content , tool_calls = inputs .additional_kwargs .get ('tool_calls' , []))
576+ message = Message (role = inputs .type , content = inputs .content ,
577+ tool_calls = inputs .additional_kwargs .get ('tool_calls' , []))
523578 return message
524579 if isinstance (inputs , ChatPromptValue ):
525580 return _convert_inputs (inputs .messages )
0 commit comments