1+ from collections import defaultdict
12from contextvars import Token
23from typing import (
34 Any ,
2829 LangfuseSpan ,
2930 LangfuseTool ,
3031)
31- from langfuse .types import TraceContext
3232from langfuse ._utils import _get_timestamp
3333from langfuse .langchain .utils import _extract_model_name
3434from langfuse .logger import langfuse_logger
35+ from langfuse .types import TraceContext
3536
3637try :
3738 import langchain
@@ -132,6 +133,7 @@ def __init__(
132133 LangfuseRetriever ,
133134 ],
134135 ] = {}
136+ self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = defaultdict (None )
135137 self .context_tokens : Dict [UUID , Token ] = {}
136138 self .prompt_to_parent_run_map : Dict [UUID , Any ] = {}
137139 self .updated_completion_start_time_memo : Set [UUID ] = set ()
@@ -302,6 +304,8 @@ def on_chain_start(
302304 metadata : Optional [Dict [str , Any ]] = None ,
303305 ** kwargs : Any ,
304306 ) -> Any :
307+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
308+
305309 try :
306310 self ._log_debug_event (
307311 "on_chain_start" , run_id , parent_run_id , inputs = inputs
@@ -480,6 +484,8 @@ def on_agent_action(
480484 ** kwargs : Any ,
481485 ) -> Any :
482486 """Run on agent action."""
487+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
488+
483489 try :
484490 self ._log_debug_event (
485491 "on_agent_action" , run_id , parent_run_id , action = action
@@ -560,6 +566,10 @@ def on_chain_end(
560566 except Exception as e :
561567 langfuse_logger .exception (e )
562568
569+ finally :
570+ if parent_run_id is None :
571+ self ._reset ()
572+
563573 def on_chain_error (
564574 self ,
565575 error : BaseException ,
@@ -603,6 +613,8 @@ def on_chat_model_start(
603613 metadata : Optional [Dict [str , Any ]] = None ,
604614 ** kwargs : Any ,
605615 ) -> Any :
616+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
617+
606618 try :
607619 self ._log_debug_event (
608620 "on_chat_model_start" , run_id , parent_run_id , messages = messages
@@ -635,6 +647,8 @@ def on_llm_start(
635647 metadata : Optional [Dict [str , Any ]] = None ,
636648 ** kwargs : Any ,
637649 ) -> Any :
650+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
651+
638652 try :
639653 self ._log_debug_event (
640654 "on_llm_start" , run_id , parent_run_id , prompts = prompts
@@ -662,6 +676,8 @@ def on_tool_start(
662676 metadata : Optional [Dict [str , Any ]] = None ,
663677 ** kwargs : Any ,
664678 ) -> Any :
679+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
680+
665681 try :
666682 self ._log_debug_event (
667683 "on_tool_start" , run_id , parent_run_id , input_str = input_str
@@ -704,6 +720,8 @@ def on_retriever_start(
704720 metadata : Optional [Dict [str , Any ]] = None ,
705721 ** kwargs : Any ,
706722 ) -> Any :
723+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
724+
707725 try :
708726 self ._log_debug_event (
709727 "on_retriever_start" , run_id , parent_run_id , query = query
@@ -809,6 +827,8 @@ def __on_llm_action(
809827 metadata : Optional [Dict [str , Any ]] = None ,
810828 ** kwargs : Any ,
811829 ) -> None :
830+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
831+
812832 try :
813833 tools = kwargs .get ("invocation_params" , {}).get ("tools" , None )
814834 if tools and isinstance (tools , list ):
@@ -817,14 +837,23 @@ def __on_llm_action(
817837 model_name = self ._parse_model_and_log_errors (
818838 serialized = serialized , metadata = metadata , kwargs = kwargs
819839 )
820- registered_prompt = (
821- self .prompt_to_parent_run_map .get (parent_run_id )
822- if parent_run_id is not None
823- else None
824- )
825840
826- if registered_prompt :
827- self ._deregister_langfuse_prompt (parent_run_id )
841+ registered_prompt = None
842+ current_parent_run_id = parent_run_id
843+
844+ # Check all parents for registered prompt
845+ while current_parent_run_id is not None :
846+ registered_prompt = self .prompt_to_parent_run_map .get (
847+ current_parent_run_id
848+ )
849+
850+ if registered_prompt :
851+ self ._deregister_langfuse_prompt (current_parent_run_id )
852+ break
853+ else :
854+ current_parent_run_id = self ._child_to_parent_run_id_map .get (
855+ current_parent_run_id
856+ )
828857
829858 content = {
830859 "name" : self .get_langchain_run_name (serialized , ** kwargs ),
@@ -956,6 +985,9 @@ def on_llm_end(
956985 finally :
957986 self .updated_completion_start_time_memo .discard (run_id )
958987
988+ if parent_run_id is None :
989+ self ._reset ()
990+
959991 def on_llm_error (
960992 self ,
961993 error : BaseException ,
@@ -980,6 +1012,9 @@ def on_llm_error(
9801012 except Exception as e :
9811013 langfuse_logger .exception (e )
9821014
1015+ def _reset (self ) -> None :
1016+ self ._child_to_parent_run_id_map = defaultdict (None )
1017+
9831018 def __join_tags_and_metadata (
9841019 self ,
9851020 tags : Optional [List [str ]] = None ,
@@ -1047,7 +1082,7 @@ def _log_debug_event(
10471082 ** kwargs : Any ,
10481083 ) -> None :
10491084 langfuse_logger .debug (
1050- f"Event: { event_name } , run_id: { str ( run_id )[: 5 ] } , parent_run_id: { str ( parent_run_id )[: 5 ] } "
1085+ f"Event: { event_name } , run_id: { run_id } , parent_run_id: { parent_run_id } "
10511086 )
10521087
10531088
@@ -1210,7 +1245,9 @@ def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
12101245 usage_model ["input" ] = max (0 , usage_model ["input" ] - value )
12111246
12121247 if f"input_modality_{ item ['modality' ]} " in usage_model :
1213- usage_model [f"input_modality_{ item ['modality' ]} " ] = max (0 , usage_model [f"input_modality_{ item ['modality' ]} " ] - value )
1248+ usage_model [f"input_modality_{ item ['modality' ]} " ] = max (
1249+ 0 , usage_model [f"input_modality_{ item ['modality' ]} " ] - value
1250+ )
12141251
12151252 usage_model = {k : v for k , v in usage_model .items () if isinstance (v , int )}
12161253
0 commit comments