Skip to content

Commit 012fbac

Browse files
committed
fix(langchain): allow prompt linking with langchain v1 create_agent
1 parent 32dddd5 commit 012fbac

File tree

1 file changed

+47
-10
lines changed

1 file changed

+47
-10
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from contextvars import Token
23
from typing import (
34
Any,
@@ -28,10 +29,10 @@
2829
LangfuseSpan,
2930
LangfuseTool,
3031
)
31-
from langfuse.types import TraceContext
3232
from langfuse._utils import _get_timestamp
3333
from langfuse.langchain.utils import _extract_model_name
3434
from langfuse.logger import langfuse_logger
35+
from langfuse.types import TraceContext
3536

3637
try:
3738
import langchain
@@ -132,6 +133,7 @@ def __init__(
132133
LangfuseRetriever,
133134
],
134135
] = {}
136+
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = defaultdict(None)
135137
self.context_tokens: Dict[UUID, Token] = {}
136138
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
137139
self.updated_completion_start_time_memo: Set[UUID] = set()
@@ -302,6 +304,8 @@ def on_chain_start(
302304
metadata: Optional[Dict[str, Any]] = None,
303305
**kwargs: Any,
304306
) -> Any:
307+
self._child_to_parent_run_id_map[run_id] = parent_run_id
308+
305309
try:
306310
self._log_debug_event(
307311
"on_chain_start", run_id, parent_run_id, inputs=inputs
@@ -480,6 +484,8 @@ def on_agent_action(
480484
**kwargs: Any,
481485
) -> Any:
482486
"""Run on agent action."""
487+
self._child_to_parent_run_id_map[run_id] = parent_run_id
488+
483489
try:
484490
self._log_debug_event(
485491
"on_agent_action", run_id, parent_run_id, action=action
@@ -560,6 +566,10 @@ def on_chain_end(
560566
except Exception as e:
561567
langfuse_logger.exception(e)
562568

569+
finally:
570+
if parent_run_id is None:
571+
self._reset()
572+
563573
def on_chain_error(
564574
self,
565575
error: BaseException,
@@ -603,6 +613,8 @@ def on_chat_model_start(
603613
metadata: Optional[Dict[str, Any]] = None,
604614
**kwargs: Any,
605615
) -> Any:
616+
self._child_to_parent_run_id_map[run_id] = parent_run_id
617+
606618
try:
607619
self._log_debug_event(
608620
"on_chat_model_start", run_id, parent_run_id, messages=messages
@@ -635,6 +647,8 @@ def on_llm_start(
635647
metadata: Optional[Dict[str, Any]] = None,
636648
**kwargs: Any,
637649
) -> Any:
650+
self._child_to_parent_run_id_map[run_id] = parent_run_id
651+
638652
try:
639653
self._log_debug_event(
640654
"on_llm_start", run_id, parent_run_id, prompts=prompts
@@ -662,6 +676,8 @@ def on_tool_start(
662676
metadata: Optional[Dict[str, Any]] = None,
663677
**kwargs: Any,
664678
) -> Any:
679+
self._child_to_parent_run_id_map[run_id] = parent_run_id
680+
665681
try:
666682
self._log_debug_event(
667683
"on_tool_start", run_id, parent_run_id, input_str=input_str
@@ -704,6 +720,8 @@ def on_retriever_start(
704720
metadata: Optional[Dict[str, Any]] = None,
705721
**kwargs: Any,
706722
) -> Any:
723+
self._child_to_parent_run_id_map[run_id] = parent_run_id
724+
707725
try:
708726
self._log_debug_event(
709727
"on_retriever_start", run_id, parent_run_id, query=query
@@ -809,6 +827,8 @@ def __on_llm_action(
809827
metadata: Optional[Dict[str, Any]] = None,
810828
**kwargs: Any,
811829
) -> None:
830+
self._child_to_parent_run_id_map[run_id] = parent_run_id
831+
812832
try:
813833
tools = kwargs.get("invocation_params", {}).get("tools", None)
814834
if tools and isinstance(tools, list):
@@ -817,14 +837,23 @@ def __on_llm_action(
817837
model_name = self._parse_model_and_log_errors(
818838
serialized=serialized, metadata=metadata, kwargs=kwargs
819839
)
820-
registered_prompt = (
821-
self.prompt_to_parent_run_map.get(parent_run_id)
822-
if parent_run_id is not None
823-
else None
824-
)
825840

826-
if registered_prompt:
827-
self._deregister_langfuse_prompt(parent_run_id)
841+
registered_prompt = None
842+
current_parent_run_id = parent_run_id
843+
844+
# Check all parents for registered prompt
845+
while current_parent_run_id is not None:
846+
registered_prompt = self.prompt_to_parent_run_map.get(
847+
current_parent_run_id
848+
)
849+
850+
if registered_prompt:
851+
self._deregister_langfuse_prompt(current_parent_run_id)
852+
break
853+
else:
854+
current_parent_run_id = self._child_to_parent_run_id_map.get(
855+
current_parent_run_id
856+
)
828857

829858
content = {
830859
"name": self.get_langchain_run_name(serialized, **kwargs),
@@ -956,6 +985,9 @@ def on_llm_end(
956985
finally:
957986
self.updated_completion_start_time_memo.discard(run_id)
958987

988+
if parent_run_id is None:
989+
self._reset()
990+
959991
def on_llm_error(
960992
self,
961993
error: BaseException,
@@ -980,6 +1012,9 @@ def on_llm_error(
9801012
except Exception as e:
9811013
langfuse_logger.exception(e)
9821014

1015+
def _reset(self) -> None:
1016+
self._child_to_parent_run_id_map = defaultdict(None)
1017+
9831018
def __join_tags_and_metadata(
9841019
self,
9851020
tags: Optional[List[str]] = None,
@@ -1047,7 +1082,7 @@ def _log_debug_event(
10471082
**kwargs: Any,
10481083
) -> None:
10491084
langfuse_logger.debug(
1050-
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}"
1085+
f"Event: {event_name}, run_id: {run_id}, parent_run_id: {parent_run_id}"
10511086
)
10521087

10531088

@@ -1210,7 +1245,9 @@ def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
12101245
usage_model["input"] = max(0, usage_model["input"] - value)
12111246

12121247
if f"input_modality_{item['modality']}" in usage_model:
1213-
usage_model[f"input_modality_{item['modality']}"] = max(0, usage_model[f"input_modality_{item['modality']}"] - value)
1248+
usage_model[f"input_modality_{item['modality']}"] = max(
1249+
0, usage_model[f"input_modality_{item['modality']}"] - value
1250+
)
12141251

12151252
usage_model = {k: v for k, v in usage_model.items() if isinstance(v, int)}
12161253

0 commit comments

Comments
 (0)