11import asyncio
2+ import contextvars
23import inspect
34import logging
45import os
2122from opentelemetry .util ._decorator import _AgnosticContextManager
2223from typing_extensions import ParamSpec
2324
24- from langfuse ._client .environment_variables import (
25- LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
26- )
27-
2825from langfuse ._client .constants import (
2926 ObservationTypeLiteralNoEvent ,
3027 get_observation_types_list ,
3128)
29+ from langfuse ._client .environment_variables import (
30+ LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
31+ )
3232from langfuse ._client .get_client import _set_current_public_key , get_client
3333from langfuse ._client .span import (
34- LangfuseGeneration ,
35- LangfuseSpan ,
3634 LangfuseAgent ,
37- LangfuseTool ,
3835 LangfuseChain ,
39- LangfuseRetriever ,
40- LangfuseEvaluator ,
4136 LangfuseEmbedding ,
37+ LangfuseEvaluator ,
38+ LangfuseGeneration ,
4239 LangfuseGuardrail ,
40+ LangfuseRetriever ,
41+ LangfuseSpan ,
42+ LangfuseTool ,
4343)
4444from langfuse .types import TraceContext
4545
@@ -468,27 +468,69 @@ def _wrap_sync_generator_result(
468468 generator : Generator ,
469469 transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
470470 ) -> Any :
471- items = []
472-
473- try :
474- for item in generator :
475- items .append (item )
476-
477- yield item
478-
479- finally :
480- output : Any = items
481-
482- if transform_to_string is not None :
483- output = transform_to_string (items )
484-
485- elif all (isinstance (item , str ) for item in items ):
486- output = "" .join (items )
487-
488- langfuse_span_or_generation .update (output = output )
489- langfuse_span_or_generation .end ()
471+ # Capture the current context while the span is still active
472+ preserved_context = contextvars .copy_context ()
473+ items : list [Any ] = []
474+
475+ class ContextPreservedSyncGeneratorWrapper :
476+ """Sync generator wrapper that ensures each iteration runs in preserved context."""
477+
478+ def __init__ (
479+ self ,
480+ generator : Generator ,
481+ context : contextvars .Context ,
482+ items : list [Any ],
483+ span : Union [
484+ LangfuseSpan ,
485+ LangfuseGeneration ,
486+ LangfuseAgent ,
487+ LangfuseTool ,
488+ LangfuseChain ,
489+ LangfuseRetriever ,
490+ LangfuseEvaluator ,
491+ LangfuseEmbedding ,
492+ LangfuseGuardrail ,
493+ ],
494+ transform_fn : Optional [Callable [[Iterable ], str ]],
495+ ) -> None :
496+ self .generator = generator
497+ self .context = context
498+ self .items = items
499+ self .span = span
500+ self .transform_fn = transform_fn
501+
502+ def __iter__ (self ) -> "ContextPreservedSyncGeneratorWrapper" :
503+ return self
504+
505+ def __next__ (self ) -> Any :
506+ try :
507+ # Run the generator's __next__ in the preserved context
508+ item = self .context .run (next , self .generator )
509+ self .items .append (item )
510+ return item
511+
512+ except StopIteration :
513+ # Handle output and span cleanup when generator is exhausted
514+ output : Any = self .items
515+
516+ if self .transform_fn is not None :
517+ output = self .transform_fn (self .items )
518+ elif all (isinstance (item , str ) for item in self .items ):
519+ output = "" .join (self .items )
520+
521+ self .span .update (output = output )
522+ self .span .end ()
523+ raise # Re-raise StopIteration
524+
525+ return ContextPreservedSyncGeneratorWrapper (
526+ generator ,
527+ preserved_context ,
528+ items ,
529+ langfuse_span_or_generation ,
530+ transform_to_string ,
531+ )
490532
491- async def _wrap_async_generator_result (
533+ def _wrap_async_generator_result (
492534 self ,
493535 langfuse_span_or_generation : Union [
494536 LangfuseSpan ,
@@ -503,26 +545,79 @@ async def _wrap_async_generator_result(
503545 ],
504546 generator : AsyncGenerator ,
505547 transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
506- ) -> AsyncGenerator :
507- items = []
508-
509- try :
510- async for item in generator :
511- items .append (item )
512-
513- yield item
514-
515- finally :
516- output : Any = items
517-
518- if transform_to_string is not None :
519- output = transform_to_string (items )
520-
521- elif all (isinstance (item , str ) for item in items ):
522- output = "" .join (items )
523-
524- langfuse_span_or_generation .update (output = output )
525- langfuse_span_or_generation .end ()
548+ ) -> Any :
549+ import asyncio
550+
551+ # Capture the current context while the span is still active
552+ preserved_context = contextvars .copy_context ()
553+ items : list [Any ] = []
554+
555+ class ContextPreservedAsyncGeneratorWrapper :
556+ """Async generator wrapper that ensures each iteration runs in preserved context."""
557+
558+ def __init__ (
559+ self ,
560+ generator : AsyncGenerator ,
561+ context : contextvars .Context ,
562+ items : list [Any ],
563+ span : Union [
564+ LangfuseSpan ,
565+ LangfuseGeneration ,
566+ LangfuseAgent ,
567+ LangfuseTool ,
568+ LangfuseChain ,
569+ LangfuseRetriever ,
570+ LangfuseEvaluator ,
571+ LangfuseEmbedding ,
572+ LangfuseGuardrail ,
573+ ],
574+ transform_fn : Optional [Callable [[Iterable ], str ]],
575+ ) -> None :
576+ self .generator = generator
577+ self .context = context
578+ self .items = items
579+ self .span = span
580+ self .transform_fn = transform_fn
581+
582+ def __aiter__ (self ) -> "ContextPreservedAsyncGeneratorWrapper" :
583+ return self
584+
585+ async def __anext__ (self ) -> Any :
586+ try :
587+ # Run the generator's __anext__ in the preserved context
588+ try :
589+ # Python 3.10+ approach with context parameter
590+ item = await asyncio .create_task (
591+ self .generator .__anext__ (), # type: ignore
592+ context = self .context ,
593+ ) # type: ignore
594+ except TypeError :
595+ # Python < 3.10 fallback - context parameter not supported
596+ item = await self .generator .__anext__ ()
597+
598+ self .items .append (item )
599+ return item
600+
601+ except StopAsyncIteration :
602+ # Handle output and span cleanup when generator is exhausted
603+ output : Any = self .items
604+
605+ if self .transform_fn is not None :
606+ output = self .transform_fn (self .items )
607+ elif all (isinstance (item , str ) for item in self .items ):
608+ output = "" .join (self .items )
609+
610+ self .span .update (output = output )
611+ self .span .end ()
612+ raise # Re-raise StopAsyncIteration
613+
614+ return ContextPreservedAsyncGeneratorWrapper (
615+ generator ,
616+ preserved_context ,
617+ items ,
618+ langfuse_span_or_generation ,
619+ transform_to_string ,
620+ )
526621
527622
528623_decorator = LangfuseDecorator ()
0 commit comments