Skip to content

Commit da8e1be

Browse files
authored
Be able to inject trace_context into LangchainCallbackHandler (#1419)
1 parent 2f0da0b commit da8e1be

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
LangfuseSpan,
2929
LangfuseTool,
3030
)
31+
from langfuse.types import TraceContext
3132
from langfuse._utils import _get_timestamp
3233
from langfuse.langchain.utils import _extract_model_name
3334
from langfuse.logger import langfuse_logger
@@ -92,7 +93,11 @@
9293

9394
class LangchainCallbackHandler(LangchainBaseCallbackHandler):
9495
def __init__(
95-
self, *, public_key: Optional[str] = None, update_trace: bool = False
96+
self,
97+
*,
98+
public_key: Optional[str] = None,
99+
update_trace: bool = False,
100+
trace_context: Optional[TraceContext] = None,
96101
) -> None:
97102
"""Initialize the LangchainCallbackHandler.
98103
@@ -120,6 +125,7 @@ def __init__(
120125

121126
self.last_trace_id: Optional[str] = None
122127
self.update_trace = update_trace
128+
self.trace_context = trace_context
123129

124130
def on_llm_new_token(
125131
self,
@@ -299,16 +305,31 @@ def on_chain_start(
299305
serialized, "chain", **kwargs
300306
)
301307

302-
span = self._get_parent_observation(parent_run_id).start_observation(
303-
name=span_name,
304-
as_type=observation_type,
305-
metadata=span_metadata,
306-
input=inputs,
307-
level=cast(
308-
Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
309-
span_level,
310-
),
311-
)
308+
obs = self._get_parent_observation(parent_run_id)
309+
if isinstance(obs, Langfuse):
310+
span = obs.start_observation(
311+
trace_context=self.trace_context,
312+
name=span_name,
313+
as_type=observation_type,
314+
metadata=span_metadata,
315+
input=inputs,
316+
level=cast(
317+
Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None,
318+
span_level,
319+
),
320+
)
321+
else:
322+
span = obs.start_observation(
323+
name=span_name,
324+
as_type=observation_type,
325+
metadata=span_metadata,
326+
input=inputs,
327+
level=cast(
328+
Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None,
329+
span_level,
330+
),
331+
)
332+
312333
self._attach_observation(run_id, span)
313334

314335
if parent_run_id is None:

0 commit comments

Comments
 (0)