From 24311925d311a1d0b88bf6f6de52120854d0ac64 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:28:22 +0200 Subject: [PATCH 1/3] feat(langchain): set callback spans to active in OTEL context --- langfuse/langchain/CallbackHandler.py | 232 +++++++++++++++++--------- 1 file changed, 149 insertions(+), 83 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index b3cc76df4..03d84522e 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -1,16 +1,18 @@ import typing +from contextvars import Token import pydantic +from opentelemetry import context, trace -from langfuse._client.get_client import get_client from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse._client.get_client import get_client from langfuse._client.span import ( - LangfuseGeneration, - LangfuseSpan, LangfuseAgent, LangfuseChain, - LangfuseTool, + LangfuseGeneration, LangfuseRetriever, + LangfuseSpan, + LangfuseTool, ) from langfuse.logger import langfuse_logger @@ -86,6 +88,7 @@ def __init__( LangfuseRetriever, ], ] = {} + self.context_tokens: Dict[UUID, Token] = {} self.prompt_to_parent_run_map: Dict[UUID, Any] = {} self.updated_completion_start_time_memo: Set[UUID] = set() @@ -210,11 +213,14 @@ def on_retriever_error( if run_id is None or run_id not in self.runs: raise Exception("run not found") - self.runs[run_id].update( - level="ERROR", - status_message=str(error), - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) + + if observation is not None: + observation.update( + level="ERROR", + status_message=str(error), + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -281,6 +287,8 @@ def on_chain_start( span_level, ), ) + self._attach_observation(run_id, span) + span.update_trace( **( cast( @@ -296,9 +304,8 @@ def on_chain_start( ), **self._parse_langfuse_trace_attributes_from_metadata(metadata), ) - self.runs[run_id] = span else: - self.runs[run_id] = cast( + span = cast( LangfuseChain, self.runs[parent_run_id], ).start_observation( @@ -312,6 +319,8 @@ def on_chain_start( ), ) + self._attach_observation(run_id, span) + self.last_trace_id = self.runs[run_id].trace_id except Exception as e: @@ -347,6 +356,53 @@ def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None: if run_id is not None and run_id in self.prompt_to_parent_run_map: del self.prompt_to_parent_run_map[run_id] + def _attach_observation( + self, + run_id: UUID, + observation: Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ], + ) -> None: + ctx = trace.set_span_in_context(observation._otel_span) + token = context.attach(ctx) + + self.runs[run_id] = observation + self.context_tokens[run_id] = token + + def _detach_observation( + self, run_id: UUID + ) -> Optional[ + Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ] + ]: + token = self.context_tokens.pop(run_id, None) + + if token: + context.detach(token) + + return cast( + Union[ + LangfuseAgent, + LangfuseChain, + LangfuseGeneration, + LangfuseRetriever, + LangfuseSpan, + LangfuseTool, + ], + self.runs.pop(run_id, None), + ) + def on_agent_action( self, action: AgentAction, @@ -393,16 +449,17 @@ def on_agent_finish( if run_id not in self.runs: raise Exception("run not found") - agent_run = self.runs[run_id] - if hasattr(agent_run, "_otel_span"): + agent_run = self._detach_observation(run_id) + + if agent_run is not None: agent_run._otel_span.set_attribute( LangfuseOtelSpanAttributes.OBSERVATION_TYPE, "agent" ) - agent_run.update( - output=finish, - input=kwargs.get("inputs"), - ).end() + agent_run.update( + output=finish, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -423,20 +480,20 @@ def on_chain_end( if run_id not in self.runs: raise Exception("run not found") - span = self.runs[run_id] - span.update( - output=outputs, - input=kwargs.get("inputs"), - ) + span = self._detach_observation(run_id) - if parent_run_id is None and self.update_trace: - span.update_trace(output=outputs, input=kwargs.get("inputs")) + if span is not None: + span.update( + output=outputs, + input=kwargs.get("inputs"), + ) - span.end() + if parent_run_id is None and self.update_trace: + span.update_trace(output=outputs, input=kwargs.get("inputs")) - del self.runs[run_id] + span.end() - self._deregister_langfuse_prompt(run_id) + self._deregister_langfuse_prompt(run_id) except Exception as e: langfuse_logger.exception(e) @@ -458,15 +515,18 @@ def on_chain_error( else: level = "ERROR" - self.runs[run_id].update( - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], level - ), - status_message=str(error) if level else None, - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) + + if observation is not None: + observation.update( + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), + status_message=str(error) if level else None, + input=kwargs.get("inputs"), + ).end() - del self.runs[run_id] else: langfuse_logger.warning( f"Run ID {run_id} already popped from run map. Could not update run with error message" @@ -565,18 +625,19 @@ def on_tool_start( if parent_run_id is None or parent_run_id not in self.runs: # Create root observation for direct tool calls - self.runs[run_id] = self.client.start_observation( + span = self.client.start_observation( name=self.get_langchain_run_name(serialized, **kwargs), as_type=observation_type, input=input_str, metadata=meta, level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, ) + + self._attach_observation(run_id, span) + else: # Create child observation for tools within chains/agents - self.runs[run_id] = cast( - LangfuseChain, self.runs[parent_run_id] - ).start_observation( + span = cast(LangfuseChain, self.runs[parent_run_id]).start_observation( name=self.get_langchain_run_name(serialized, **kwargs), as_type=observation_type, input=input_str, @@ -584,6 +645,8 @@ def on_tool_start( level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, ) + self._attach_observation(run_id, span) + except Exception as e: langfuse_logger.exception(e) @@ -611,7 +674,7 @@ def on_retriever_start( ) if parent_run_id is None: - self.runs[run_id] = self.client.start_observation( + span = self.client.start_observation( name=span_name, as_type=observation_type, metadata=span_metadata, @@ -621,8 +684,11 @@ def on_retriever_start( span_level, ), ) + + self._attach_observation(run_id, span) + else: - self.runs[run_id] = cast( + span = cast( LangfuseRetriever, self.runs[parent_run_id] ).start_observation( name=span_name, @@ -635,6 +701,8 @@ def on_retriever_start( ), ) + self._attach_observation(run_id, span) + except Exception as e: langfuse_logger.exception(e) @@ -650,15 +718,13 @@ def on_retriever_end( self._log_debug_event( "on_retriever_end", run_id, parent_run_id, documents=documents ) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") - - self.runs[run_id].update( - output=documents, - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) - del self.runs[run_id] + if observation is not None: + observation.update( + output=documents, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -673,15 +739,14 @@ def on_tool_end( ) -> Any: try: self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") - self.runs[run_id].update( - output=output, - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) - del self.runs[run_id] + if observation is not None: + observation.update( + output=output, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -696,16 +761,14 @@ def on_tool_error( ) -> Any: try: self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) - if run_id is None or run_id not in self.runs: - raise Exception("run not found") - - self.runs[run_id].update( - status_message=str(error), - level="ERROR", - input=kwargs.get("inputs"), - ).end() + observation = self._detach_observation(run_id) - del self.runs[run_id] + if observation is not None: + observation.update( + status_message=str(error), + level="ERROR", + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -754,14 +817,18 @@ def __on_llm_action( } if parent_run_id is not None and parent_run_id in self.runs: - self.runs[run_id] = cast( + generation = cast( LangfuseGeneration, self.runs[parent_run_id] ).start_observation(as_type="generation", **content) # type: ignore + + self._attach_observation(run_id, generation) else: - self.runs[run_id] = self.client.start_observation( + generation = self.client.start_observation( as_type="generation", **content ) # type: ignore + self._attach_observation(run_id, generation) + self.last_trace_id = self.runs[run_id].trace_id except Exception as e: @@ -856,17 +923,17 @@ def on_llm_end( # e.g. azure returns the model name in the response model = _parse_model(response) - langfuse_generation = cast(LangfuseGeneration, self.runs[run_id]) - langfuse_generation.update( - output=extracted_response, - usage=llm_usage, - usage_details=llm_usage, - input=kwargs.get("inputs"), - model=model, - ) - langfuse_generation.end() - del self.runs[run_id] + generation = self._detach_observation(run_id) + + if generation is not None: + generation.update( + output=extracted_response, + usage=llm_usage, + usage_details=llm_usage, + input=kwargs.get("inputs"), + model=model, + ).end() except Exception as e: langfuse_logger.exception(e) @@ -884,16 +951,15 @@ def on_llm_error( ) -> Any: try: self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) - if run_id in self.runs: - generation = self.runs[run_id] + + generation = self._detach_observation(run_id) + + if generation is not None: generation.update( status_message=str(error), level="ERROR", input=kwargs.get("inputs"), - ) - generation.end() - - del self.runs[run_id] + ).end() except Exception as e: langfuse_logger.exception(e) From 9f4cafefcb4ebff83b792993065c5ebb89fc3ce8 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:46:27 +0200 Subject: [PATCH 2/3] simplify observation creation --- langfuse/langchain/CallbackHandler.py | 159 ++++++++------------------ 1 file changed, 48 insertions(+), 111 deletions(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 03d84522e..f13647fb8 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -276,19 +276,19 @@ def on_chain_start( serialized, "chain", **kwargs ) - if parent_run_id is None: - span = self.client.start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) - self._attach_observation(run_id, span) + span = self.client.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + self._attach_observation(run_id, span) + if parent_run_id is None: span.update_trace( **( cast( @@ -304,22 +304,6 @@ def on_chain_start( ), **self._parse_langfuse_trace_attributes_from_metadata(metadata), ) - else: - span = cast( - LangfuseChain, - self.runs[parent_run_id], - ).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) - - self._attach_observation(run_id, span) self.last_trace_id = self.runs[run_id].trace_id @@ -509,28 +493,22 @@ def on_chain_error( ) -> None: try: self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) - if run_id in self.runs: - if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): - level = None - else: - level = "ERROR" - - observation = self._detach_observation(run_id) - - if observation is not None: - observation.update( - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - level, - ), - status_message=str(error) if level else None, - input=kwargs.get("inputs"), - ).end() - + if any(isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES): + level = None else: - langfuse_logger.warning( - f"Run ID {run_id} already popped from run map. Could not update run with error message" - ) + level = "ERROR" + + observation = self._detach_observation(run_id) + + if observation is not None: + observation.update( + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + level, + ), + status_message=str(error) if level else None, + input=kwargs.get("inputs"), + ).end() except Exception as e: langfuse_logger.exception(e) @@ -623,29 +601,15 @@ def on_tool_start( serialized, "tool", **kwargs ) - if parent_run_id is None or parent_run_id not in self.runs: - # Create root observation for direct tool calls - span = self.client.start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) - - self._attach_observation(run_id, span) - - else: - # Create child observation for tools within chains/agents - span = cast(LangfuseChain, self.runs[parent_run_id]).start_observation( - name=self.get_langchain_run_name(serialized, **kwargs), - as_type=observation_type, - input=input_str, - metadata=meta, - level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, - ) + span = self.client.start_observation( + name=self.get_langchain_run_name(serialized, **kwargs), + as_type=observation_type, + input=input_str, + metadata=meta, + level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, + ) - self._attach_observation(run_id, span) + self._attach_observation(run_id, span) except Exception as e: langfuse_logger.exception(e) @@ -673,35 +637,18 @@ def on_retriever_start( serialized, "retriever", **kwargs ) - if parent_run_id is None: - span = self.client.start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=query, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) - - self._attach_observation(run_id, span) - - else: - span = cast( - LangfuseRetriever, self.runs[parent_run_id] - ).start_observation( - name=span_name, - as_type=observation_type, - input=query, - metadata=span_metadata, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + span = self.client.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=query, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) - self._attach_observation(run_id, span) + self._attach_observation(run_id, span) except Exception as e: langfuse_logger.exception(e) @@ -816,18 +763,8 @@ def __on_llm_action( "prompt": registered_prompt, } - if parent_run_id is not None and parent_run_id in self.runs: - generation = cast( - LangfuseGeneration, self.runs[parent_run_id] - ).start_observation(as_type="generation", **content) # type: ignore - - self._attach_observation(run_id, generation) - else: - generation = self.client.start_observation( - as_type="generation", **content - ) # type: ignore - - self._attach_observation(run_id, generation) + generation = self.client.start_observation(as_type="generation", **content) + self._attach_observation(run_id, generation) self.last_trace_id = self.runs[run_id].trace_id From a0ce171d3a735c08aa69f2b8612fbf9d068fcfdd Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:53:43 +0200 Subject: [PATCH 3/3] push --- langfuse/langchain/CallbackHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index f13647fb8..ed8046e8e 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -763,7 +763,7 @@ def __on_llm_action( "prompt": registered_prompt, } - generation = self.client.start_observation(as_type="generation", **content) + generation = self.client.start_observation(as_type="generation", **content) # type: ignore self._attach_observation(run_id, generation) self.last_trace_id = self.runs[run_id].trace_id