Skip to content

Commit 370f69c

Browse files
committed
fix(observe): handle generator context propagation
1 parent 3ce7abe commit 370f69c

File tree

4 files changed

+1290
-204
lines changed

4 files changed

+1290
-204
lines changed

langfuse/_client/observe.py

Lines changed: 144 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import inspect
34
import logging
45
import os
@@ -21,25 +22,24 @@
2122
from opentelemetry.util._decorator import _AgnosticContextManager
2223
from typing_extensions import ParamSpec
2324

24-
from langfuse._client.environment_variables import (
25-
LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
26-
)
27-
2825
from langfuse._client.constants import (
2926
ObservationTypeLiteralNoEvent,
3027
get_observation_types_list,
3128
)
29+
from langfuse._client.environment_variables import (
30+
LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
31+
)
3232
from langfuse._client.get_client import _set_current_public_key, get_client
3333
from langfuse._client.span import (
34-
LangfuseGeneration,
35-
LangfuseSpan,
3634
LangfuseAgent,
37-
LangfuseTool,
3835
LangfuseChain,
39-
LangfuseRetriever,
40-
LangfuseEvaluator,
4136
LangfuseEmbedding,
37+
LangfuseEvaluator,
38+
LangfuseGeneration,
4239
LangfuseGuardrail,
40+
LangfuseRetriever,
41+
LangfuseSpan,
42+
LangfuseTool,
4343
)
4444
from langfuse.types import TraceContext
4545

@@ -468,27 +468,69 @@ def _wrap_sync_generator_result(
468468
generator: Generator,
469469
transform_to_string: Optional[Callable[[Iterable], str]] = None,
470470
) -> Any:
471-
items = []
472-
473-
try:
474-
for item in generator:
475-
items.append(item)
476-
477-
yield item
478-
479-
finally:
480-
output: Any = items
481-
482-
if transform_to_string is not None:
483-
output = transform_to_string(items)
484-
485-
elif all(isinstance(item, str) for item in items):
486-
output = "".join(items)
487-
488-
langfuse_span_or_generation.update(output=output)
489-
langfuse_span_or_generation.end()
471+
# Capture the current context while the span is still active
472+
preserved_context = contextvars.copy_context()
473+
items: list[Any] = []
474+
475+
class ContextPreservedSyncGeneratorWrapper:
476+
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
477+
478+
def __init__(
479+
self,
480+
generator: Generator,
481+
context: contextvars.Context,
482+
items: list[Any],
483+
span: Union[
484+
LangfuseSpan,
485+
LangfuseGeneration,
486+
LangfuseAgent,
487+
LangfuseTool,
488+
LangfuseChain,
489+
LangfuseRetriever,
490+
LangfuseEvaluator,
491+
LangfuseEmbedding,
492+
LangfuseGuardrail,
493+
],
494+
transform_fn: Optional[Callable[[Iterable], str]],
495+
) -> None:
496+
self.generator = generator
497+
self.context = context
498+
self.items = items
499+
self.span = span
500+
self.transform_fn = transform_fn
501+
502+
def __iter__(self) -> "ContextPreservedSyncGeneratorWrapper":
503+
return self
504+
505+
def __next__(self) -> Any:
506+
try:
507+
# Run the generator's __next__ in the preserved context
508+
item = self.context.run(next, self.generator)
509+
self.items.append(item)
510+
return item
511+
512+
except StopIteration:
513+
# Handle output and span cleanup when generator is exhausted
514+
output: Any = self.items
515+
516+
if self.transform_fn is not None:
517+
output = self.transform_fn(self.items)
518+
elif all(isinstance(item, str) for item in self.items):
519+
output = "".join(self.items)
520+
521+
self.span.update(output=output)
522+
self.span.end()
523+
raise # Re-raise StopIteration
524+
525+
return ContextPreservedSyncGeneratorWrapper(
526+
generator,
527+
preserved_context,
528+
items,
529+
langfuse_span_or_generation,
530+
transform_to_string,
531+
)
490532

491-
async def _wrap_async_generator_result(
533+
def _wrap_async_generator_result(
492534
self,
493535
langfuse_span_or_generation: Union[
494536
LangfuseSpan,
@@ -503,26 +545,79 @@ async def _wrap_async_generator_result(
503545
],
504546
generator: AsyncGenerator,
505547
transform_to_string: Optional[Callable[[Iterable], str]] = None,
506-
) -> AsyncGenerator:
507-
items = []
508-
509-
try:
510-
async for item in generator:
511-
items.append(item)
512-
513-
yield item
514-
515-
finally:
516-
output: Any = items
517-
518-
if transform_to_string is not None:
519-
output = transform_to_string(items)
520-
521-
elif all(isinstance(item, str) for item in items):
522-
output = "".join(items)
523-
524-
langfuse_span_or_generation.update(output=output)
525-
langfuse_span_or_generation.end()
548+
) -> Any:
549+
import asyncio
550+
551+
# Capture the current context while the span is still active
552+
preserved_context = contextvars.copy_context()
553+
items: list[Any] = []
554+
555+
class ContextPreservedAsyncGeneratorWrapper:
556+
"""Async generator wrapper that ensures each iteration runs in preserved context."""
557+
558+
def __init__(
559+
self,
560+
generator: AsyncGenerator,
561+
context: contextvars.Context,
562+
items: list[Any],
563+
span: Union[
564+
LangfuseSpan,
565+
LangfuseGeneration,
566+
LangfuseAgent,
567+
LangfuseTool,
568+
LangfuseChain,
569+
LangfuseRetriever,
570+
LangfuseEvaluator,
571+
LangfuseEmbedding,
572+
LangfuseGuardrail,
573+
],
574+
transform_fn: Optional[Callable[[Iterable], str]],
575+
) -> None:
576+
self.generator = generator
577+
self.context = context
578+
self.items = items
579+
self.span = span
580+
self.transform_fn = transform_fn
581+
582+
def __aiter__(self) -> "ContextPreservedAsyncGeneratorWrapper":
583+
return self
584+
585+
async def __anext__(self) -> Any:
586+
try:
587+
# Run the generator's __anext__ in the preserved context
588+
try:
589+
# Python 3.10+ approach with context parameter
590+
item = await asyncio.create_task(
591+
self.generator.__anext__(), # type: ignore
592+
context=self.context,
593+
) # type: ignore
594+
except TypeError:
595+
# Python < 3.10 fallback - context parameter not supported
596+
item = await self.generator.__anext__()
597+
598+
self.items.append(item)
599+
return item
600+
601+
except StopAsyncIteration:
602+
# Handle output and span cleanup when generator is exhausted
603+
output: Any = self.items
604+
605+
if self.transform_fn is not None:
606+
output = self.transform_fn(self.items)
607+
elif all(isinstance(item, str) for item in self.items):
608+
output = "".join(self.items)
609+
610+
self.span.update(output=output)
611+
self.span.end()
612+
raise # Re-raise StopAsyncIteration
613+
614+
return ContextPreservedAsyncGeneratorWrapper(
615+
generator,
616+
preserved_context,
617+
items,
618+
langfuse_span_or_generation,
619+
transform_to_string,
620+
)
526621

527622

528623
_decorator = LangfuseDecorator()

0 commit comments

Comments
 (0)