Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions langfuse/langchain/CallbackHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
LangfuseSpan,
LangfuseTool,
)
from langfuse.types import TraceContext
from langfuse._utils import _get_timestamp
from langfuse.langchain.utils import _extract_model_name
from langfuse.logger import langfuse_logger
from langfuse.types import TraceContext

try:
import langchain
Expand Down Expand Up @@ -132,6 +132,7 @@ def __init__(
LangfuseRetriever,
],
] = {}
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
self.context_tokens: Dict[UUID, Token] = {}
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
self.updated_completion_start_time_memo: Set[UUID] = set()
Expand Down Expand Up @@ -302,6 +303,8 @@ def on_chain_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_chain_start", run_id, parent_run_id, inputs=inputs
Expand Down Expand Up @@ -480,6 +483,8 @@ def on_agent_action(
**kwargs: Any,
) -> Any:
"""Run on agent action."""
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_agent_action", run_id, parent_run_id, action=action
Expand Down Expand Up @@ -560,6 +565,10 @@ def on_chain_end(
except Exception as e:
langfuse_logger.exception(e)

finally:
if parent_run_id is None:
self._reset()

def on_chain_error(
self,
error: BaseException,
Expand Down Expand Up @@ -603,6 +612,8 @@ def on_chat_model_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_chat_model_start", run_id, parent_run_id, messages=messages
Expand Down Expand Up @@ -635,6 +646,8 @@ def on_llm_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_llm_start", run_id, parent_run_id, prompts=prompts
Expand Down Expand Up @@ -662,6 +675,8 @@ def on_tool_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_tool_start", run_id, parent_run_id, input_str=input_str
Expand Down Expand Up @@ -704,6 +719,8 @@ def on_retriever_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
self._log_debug_event(
"on_retriever_start", run_id, parent_run_id, query=query
Expand Down Expand Up @@ -809,6 +826,8 @@ def __on_llm_action(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self._child_to_parent_run_id_map[run_id] = parent_run_id

try:
tools = kwargs.get("invocation_params", {}).get("tools", None)
if tools and isinstance(tools, list):
Expand All @@ -817,14 +836,23 @@ def __on_llm_action(
model_name = self._parse_model_and_log_errors(
serialized=serialized, metadata=metadata, kwargs=kwargs
)
registered_prompt = (
self.prompt_to_parent_run_map.get(parent_run_id)
if parent_run_id is not None
else None
)

if registered_prompt:
self._deregister_langfuse_prompt(parent_run_id)
registered_prompt = None
current_parent_run_id = parent_run_id

# Check all parents for registered prompt
while current_parent_run_id is not None:
registered_prompt = self.prompt_to_parent_run_map.get(
current_parent_run_id
)

if registered_prompt:
self._deregister_langfuse_prompt(current_parent_run_id)
break
else:
current_parent_run_id = self._child_to_parent_run_id_map.get(
current_parent_run_id, None
)

content = {
"name": self.get_langchain_run_name(serialized, **kwargs),
Expand Down Expand Up @@ -956,6 +984,9 @@ def on_llm_end(
finally:
self.updated_completion_start_time_memo.discard(run_id)

if parent_run_id is None:
self._reset()

def on_llm_error(
self,
error: BaseException,
Expand All @@ -980,6 +1011,9 @@ def on_llm_error(
except Exception as e:
langfuse_logger.exception(e)

def _reset(self) -> None:
self._child_to_parent_run_id_map = {}

def __join_tags_and_metadata(
self,
tags: Optional[List[str]] = None,
Expand Down Expand Up @@ -1047,7 +1081,7 @@ def _log_debug_event(
**kwargs: Any,
) -> None:
langfuse_logger.debug(
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}"
f"Event: {event_name}, run_id: {run_id}, parent_run_id: {parent_run_id}"
)


Expand Down Expand Up @@ -1210,7 +1244,9 @@ def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
usage_model["input"] = max(0, usage_model["input"] - value)

if f"input_modality_{item['modality']}" in usage_model:
usage_model[f"input_modality_{item['modality']}"] = max(0, usage_model[f"input_modality_{item['modality']}"] - value)
usage_model[f"input_modality_{item['modality']}"] = max(
0, usage_model[f"input_modality_{item['modality']}"] - value
)

usage_model = {k: v for k, v in usage_model.items() if isinstance(v, int)}

Expand Down
Loading