diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 7c898807c..ba2460c47 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -56,7 +56,15 @@ class LangchainCallbackHandler(LangchainBaseCallbackHandler): - def __init__(self, *, public_key: Optional[str] = None) -> None: + def __init__( + self, *, public_key: Optional[str] = None, update_trace: bool = False + ) -> None: + """Initialize the LangchainCallbackHandler. + + Args: + public_key: Optional Langfuse public key. If not provided, will use the default client configuration. + update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False. + """ self.client = get_client(public_key=public_key) self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {} @@ -64,6 +72,7 @@ def __init__(self, *, public_key: Optional[str] = None) -> None: self.updated_completion_start_time_memo: Set[UUID] = set() self.last_trace_id: Optional[str] = None + self.update_trace = update_trace def on_llm_new_token( self, @@ -207,7 +216,19 @@ def on_chain_start( ), ) span.update_trace( - **self._parse_langfuse_trace_attributes_from_metadata(metadata) + **( + cast( + Any, + { + "input": inputs, + "name": span_name, + "metadata": span_metadata, + }, + ) + if self.update_trace + else {} + ), + **self._parse_langfuse_trace_attributes_from_metadata(metadata), ) self.runs[run_id] = span else: @@ -322,14 +343,21 @@ def on_chain_end( if run_id not in self.runs: raise Exception("run not found") - self.runs[run_id].update( + span = self.runs[run_id] + span.update( output=outputs, input=kwargs.get("inputs"), - ).end() + ) + + if parent_run_id is None and self.update_trace: + span.update_trace(output=outputs, input=kwargs.get("inputs")) + + span.end() del self.runs[run_id] self._deregister_langfuse_prompt(run_id) + except Exception as e: langfuse_logger.exception(e)