|
28 | 28 | LangfuseSpan, |
29 | 29 | LangfuseTool, |
30 | 30 | ) |
| 31 | +from langfuse.types import TraceContext |
31 | 32 | from langfuse._utils import _get_timestamp |
32 | 33 | from langfuse.langchain.utils import _extract_model_name |
33 | 34 | from langfuse.logger import langfuse_logger |
|
92 | 93 |
|
93 | 94 | class LangchainCallbackHandler(LangchainBaseCallbackHandler): |
94 | 95 | def __init__( |
95 | | - self, *, public_key: Optional[str] = None, update_trace: bool = False |
| 96 | + self, |
| 97 | + *, |
| 98 | + public_key: Optional[str] = None, |
| 99 | + update_trace: bool = False, |
| 100 | + trace_context: Optional[TraceContext] = None, |
96 | 101 | ) -> None: |
97 | 102 | """Initialize the LangchainCallbackHandler. |
98 | 103 |
|
@@ -120,6 +125,7 @@ def __init__( |
120 | 125 |
|
121 | 126 | self.last_trace_id: Optional[str] = None |
122 | 127 | self.update_trace = update_trace |
| 128 | + self.trace_context = trace_context |
123 | 129 |
|
124 | 130 | def on_llm_new_token( |
125 | 131 | self, |
@@ -299,16 +305,31 @@ def on_chain_start( |
299 | 305 | serialized, "chain", **kwargs |
300 | 306 | ) |
301 | 307 |
|
302 | | - span = self._get_parent_observation(parent_run_id).start_observation( |
303 | | - name=span_name, |
304 | | - as_type=observation_type, |
305 | | - metadata=span_metadata, |
306 | | - input=inputs, |
307 | | - level=cast( |
308 | | - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], |
309 | | - span_level, |
310 | | - ), |
311 | | - ) |
| 308 | + obs = self._get_parent_observation(parent_run_id) |
| 309 | + if isinstance(obs, Langfuse): |
| 310 | + span = obs.start_observation( |
| 311 | + trace_context=self.trace_context, |
| 312 | + name=span_name, |
| 313 | + as_type=observation_type, |
| 314 | + metadata=span_metadata, |
| 315 | + input=inputs, |
| 316 | + level=cast( |
| 317 | + Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None, |
| 318 | + span_level, |
| 319 | + ), |
| 320 | + ) |
| 321 | + else: |
| 322 | + span = obs.start_observation( |
| 323 | + name=span_name, |
| 324 | + as_type=observation_type, |
| 325 | + metadata=span_metadata, |
| 326 | + input=inputs, |
| 327 | + level=cast( |
| 328 | + Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None, |
| 329 | + span_level, |
| 330 | + ), |
| 331 | + ) |
| 332 | + |
312 | 333 | self._attach_observation(run_id, span) |
313 | 334 |
|
314 | 335 | if parent_run_id is None: |
|
0 commit comments