diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index b3cc76df4..ed8046e8e 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -1,16 +1,18 @@ import typing +from contextvars import Token import pydantic +from opentelemetry import context, trace -from langfuse._client.get_client import get_client from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse._client.get_client import get_client from langfuse._client.span import ( - LangfuseGeneration, - LangfuseSpan, LangfuseAgent, LangfuseChain, - LangfuseTool, + LangfuseGeneration, LangfuseRetriever, + LangfuseSpan, + LangfuseTool, ) from langfuse.logger import langfuse_logger @@ -86,6 +88,7 @@ def __init__( LangfuseRetriever, ], ] = {} + self.context_tokens: Dict[UUID, Token] = {} self.prompt_to_parent_run_map: Dict[UUID, Any] = {} self.updated_completion_start_time_memo: Set[UUID] = set() @@ -210,11 +213,14 @@ def on_retriever_error( if run_id is None or run_id not in self.runs: raise Exception("run not found") - self.runs[run_id].update( - level="ERROR", - status_message=str(error), - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) + + if observation is not None: + observation.update( + level="ERROR", + status_message=str(error), + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -270,17 +276,19 @@ def on_chain_start( serialized, "chain", **kwargs ) + span = self.client.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + self._attach_observation(run_id, span) + if parent_run_id is None: - span = self.client.start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) span.update_trace( **( cast( @@ -296,21 +304,6 @@ def on_chain_start( ), **self._parse_langfuse_trace_attributes_from_metadata(metadata), ) - self.runs[run_id] = span - else: - self.runs[run_id] = cast( - LangfuseChain, - self.runs[parent_run_id], - ).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) self.last_trace_id = self.runs[run_id].trace_id @@ -347,6 +340,53 @@ def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None: if run_id is not None and run_id in self.prompt_to_parent_run_map: del self.prompt_to_parent_run_map[run_id] + def _attach_observation( + self, + run_id: UUID, + observation: Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ], + ) -> None: + ctx = trace.set_span_in_context(observation._otel_span) + token = context.attach(ctx) + + self.runs[run_id] = observation + self.context_tokens[run_id] = token + + def _detach_observation( + self, run_id: UUID + ) -> Optional[ + Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ] + ]: + token = self.context_tokens.pop(run_id, None) + + if token: + context.detach(token) + + return cast( + Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ], + self.runs.pop(run_id, None), + ) + def on_agent_action( self, action: AgentAction, @@ -393,16 +433,17 @@ def on_agent_finish( if run_id not in self.runs: raise Exception("run not found") - agent_run = self.runs[run_id] - if hasattr(agent_run, "_otel_span"): + agent_run = self._detach_observation(run_id) + + if agent_run is not None: agent_run._otel_span.set_attribute( LangfuseOtelSpanAttributes.OBSERVATION_TYPE, "agent" ) - agent_run.update( - output=finish, - input=kwargs.get("inputs"), - ).end() + agent_run.update( + output=finish, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -423,20 +464,20 @@ def on_chain_end( if run_id not in self.runs: raise Exception("run not found") - span = self.runs[run_id] - span.update( - output=outputs, - input=kwargs.get("inputs"), - ) + span = self._detach_observation(run_id) - if parent_run_id is None and self.update_trace: - span.update_trace(output=outputs, input=kwargs.get("inputs")) + if span is not None: + span.update( + output=outputs, + input=kwargs.get("inputs"), + ) - span.end() + if parent_run_id is None and self.update_trace: + span.update_trace(output=outputs, input=kwargs.get("inputs")) - del self.runs[run_id] + span.end() - self._deregister_langfuse_prompt(run_id) + self._deregister_langfuse_prompt(run_id) except Exception as e: langfuse_logger.exception(e) @@ -452,26 +493,23 @@ def on_chain_error( ) -> None: try: self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) - if run_id in self.runs: - if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): - level = None - else: - level = "ERROR" + if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): + level = None + else: + level = "ERROR" + + observation = self._detach_observation(run_id) - self.runs[run_id].update( + if observation is not None: + observation.update( level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], level + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, ), status_message=str(error) if level else None, input=kwargs.get("inputs"), ).end() - del self.runs[run_id] - else: - langfuse_logger.warning( - f"Run ID {run_id} already popped from run map. Could not update run with error message" - ) - except Exception as e: langfuse_logger.exception(e) @@ -563,26 +601,15 @@ def on_tool_start( serialized, "tool", **kwargs ) - if parent_run_id is None or parent_run_id not in self.runs: - # Create root observation for direct tool calls - self.runs[run_id] = self.client.start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) - else: - # Create child observation for tools within chains/agents - self.runs[run_id] = cast( - LangfuseChain, self.runs[parent_run_id] - ).start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) + span = self.client.start_observation( + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) + + self._attach_observation(run_id, span) except Exception as e: langfuse_logger.exception(e) @@ -610,30 +637,18 @@ def on_retriever_start( serialized, "retriever", **kwargs ) - if parent_run_id is None: - self.runs[run_id] = self.client.start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=query, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) - else: - self.runs[run_id] = cast( - LangfuseRetriever, self.runs[parent_run_id] - ).start_observation( - name=span_name, - as_type=observation_type, - input=query, - metadata=span_metadata, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + span = self.client.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + + self._attach_observation(run_id, span) except Exception as e: langfuse_logger.exception(e) @@ -650,15 +665,13 @@ def on_retriever_end( self._log_debug_event( "on_retriever_end", run_id, parent_run_id, documents=documents ) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") + observation = self._detach_observation(run_id) - self.runs[run_id].update( - output=documents, - input=kwargs.get("inputs"), - ).end() - - del self.runs[run_id] + if observation is not None: + observation.update( + output=documents, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -673,15 +686,14 @@ def on_tool_end( ) -> Any: try: self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") - self.runs[run_id].update( - output=output, - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) - del self.runs[run_id] + if observation is not None: + observation.update( + output=output, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -696,16 +708,14 @@ def on_tool_error( ) -> Any: try: self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") - - self.runs[run_id].update( - status_message=str(error), - level="ERROR", - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) - del self.runs[run_id] + if observation is not None: + observation.update( + status_message=str(error), + level="ERROR", + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -753,14 +763,8 @@ def __on_llm_action( "prompt": registered_prompt, } - if parent_run_id is not None and parent_run_id in self.runs: - self.runs[run_id] = cast( - LangfuseGeneration, self.runs[parent_run_id] - ).start_observation(as_type="generation", **content) # type: ignore - else: - self.runs[run_id] = self.client.start_observation( - as_type="generation", **content - ) # type: ignore + generation = self.client.start_observation(as_type="generation", **content) # type: ignore + self._attach_observation(run_id, generation) self.last_trace_id = self.runs[run_id].trace_id @@ -856,17 +860,17 @@ def on_llm_end( # e.g. azure returns the model name in the response model = _parse_model(response) - langfuse_generation = cast(LangfuseGeneration, self.runs[run_id]) - langfuse_generation.update( - output=extracted_response, - usage=llm_usage, - usage_details=llm_usage, - input=kwargs.get("inputs"), - model=model, - ) - langfuse_generation.end() - del self.runs[run_id] + generation = self._detach_observation(run_id) + + if generation is not None: + generation.update( + output=extracted_response, + usage=llm_usage, + usage_details=llm_usage, + input=kwargs.get("inputs"), + model=model, + ).end() except Exception as e: langfuse_logger.exception(e) @@ -884,16 +888,15 @@ def on_llm_error( ) -> Any: try: self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) - if run_id in self.runs: - generation = self.runs[run_id] + + generation = self._detach_observation(run_id) + + if generation is not None: generation.update( status_message=str(error), level="ERROR", input=kwargs.get("inputs"), - ) - generation.end() - - del self.runs[run_id] + ).end() except Exception as e: langfuse_logger.exception(e)