diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py
index ce848e04a..6b9d52278 100644
--- a/langfuse/_client/observe.py
+++ b/langfuse/_client/observe.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextvars
 import inspect
 import logging
 import os
@@ -10,6 +11,7 @@
     Dict,
     Generator,
     Iterable,
+    List,
     Optional,
     Tuple,
     TypeVar,
@@ -21,25 +23,24 @@
 from opentelemetry.util._decorator import _AgnosticContextManager
 from typing_extensions import ParamSpec
 
-from langfuse._client.environment_variables import (
-    LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
-)
-
 from langfuse._client.constants import (
     ObservationTypeLiteralNoEvent,
     get_observation_types_list,
 )
+from langfuse._client.environment_variables import (
+    LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
+)
 from langfuse._client.get_client import _set_current_public_key, get_client
 from langfuse._client.span import (
-    LangfuseGeneration,
-    LangfuseSpan,
     LangfuseAgent,
-    LangfuseTool,
     LangfuseChain,
-    LangfuseRetriever,
-    LangfuseEvaluator,
     LangfuseEmbedding,
+    LangfuseEvaluator,
+    LangfuseGeneration,
     LangfuseGuardrail,
+    LangfuseRetriever,
+    LangfuseSpan,
+    LangfuseTool,
 )
 from langfuse.types import TraceContext
@@ -468,29 +469,54 @@ def _wrap_sync_generator_result(
         generator: Generator,
         transform_to_string: Optional[Callable[[Iterable], str]] = None,
     ) -> Any:
-        items = []
+        preserved_context = contextvars.copy_context()
 
-        try:
-            for item in generator:
-                items.append(item)
+        return _ContextPreservedSyncGeneratorWrapper(
+            generator,
+            preserved_context,
+            langfuse_span_or_generation,
+            transform_to_string,
+        )
+
+    def _wrap_async_generator_result(
+        self,
+        langfuse_span_or_generation: Union[
+            LangfuseSpan,
+            LangfuseGeneration,
+            LangfuseAgent,
+            LangfuseTool,
+            LangfuseChain,
+            LangfuseRetriever,
+            LangfuseEvaluator,
+            LangfuseEmbedding,
+            LangfuseGuardrail,
+        ],
+        generator: AsyncGenerator,
+        transform_to_string: Optional[Callable[[Iterable], str]] = None,
+    ) -> Any:
+        preserved_context = contextvars.copy_context()
 
-            yield item
+        return _ContextPreservedAsyncGeneratorWrapper(
+            generator,
+            preserved_context,
+            langfuse_span_or_generation,
+            transform_to_string,
+        )
 
-        finally:
-            output: Any = items
-            if transform_to_string is not None:
-                output = transform_to_string(items)
+_decorator = LangfuseDecorator()
+
+observe = _decorator.observe
 
-            elif all(isinstance(item, str) for item in items):
-                output = "".join(items)
 
-            langfuse_span_or_generation.update(output=output)
-            langfuse_span_or_generation.end()
+class _ContextPreservedSyncGeneratorWrapper:
+    """Sync generator wrapper that ensures each iteration runs in preserved context."""
 
-    async def _wrap_async_generator_result(
+    def __init__(
         self,
-        langfuse_span_or_generation: Union[
+        generator: Generator,
+        context: contextvars.Context,
+        span: Union[
             LangfuseSpan,
             LangfuseGeneration,
             LangfuseAgent,
@@ -501,30 +527,105 @@ async def _wrap_async_generator_result(
             LangfuseTool,
             LangfuseChain,
             LangfuseRetriever,
             LangfuseEvaluator,
             LangfuseEmbedding,
             LangfuseGuardrail,
         ],
-        generator: AsyncGenerator,
-        transform_to_string: Optional[Callable[[Iterable], str]] = None,
-    ) -> AsyncGenerator:
-        items = []
+        transform_fn: Optional[Callable[[Iterable], str]],
+    ) -> None:
+        self.generator = generator
+        self.context = context
+        self.items: List[Any] = []
+        self.span = span
+        self.transform_fn = transform_fn
+
+    def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
+        return self
+
+    def __next__(self) -> Any:
+        try:
+            # Run the generator's __next__ in the preserved context
+            item = self.context.run(next, self.generator)
+
+            self.items.append(item)
+
+            return item
+
+        except StopIteration:
+            # Handle output and span cleanup when generator is exhausted
+            output: Any = self.items
+
+            if self.transform_fn is not None:
+                output = self.transform_fn(self.items)
+
+            elif all(isinstance(item, str) for item in self.items):
+                output = "".join(self.items)
+
+            self.span.update(output=output).end()
+
+            raise  # Re-raise StopIteration
+
+        except Exception as e:
+            self.span.update(level="ERROR", status_message=str(e)).end()
+            raise
+
+
+class _ContextPreservedAsyncGeneratorWrapper:
+    """Async generator wrapper that ensures each iteration runs in preserved context."""
+
+    def __init__(
+        self,
+        generator: AsyncGenerator,
+        context: contextvars.Context,
+        span: Union[
+            LangfuseSpan,
+            LangfuseGeneration,
+            LangfuseAgent,
+            LangfuseTool,
+            LangfuseChain,
+            LangfuseRetriever,
+            LangfuseEvaluator,
+            LangfuseEmbedding,
+            LangfuseGuardrail,
+        ],
+        transform_fn: Optional[Callable[[Iterable], str]],
+    ) -> None:
+        self.generator = generator
+        self.context = context
+        self.items: List[Any] = []
+        self.span = span
+        self.transform_fn = transform_fn
+
+    def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
+        return self
+
+    async def __anext__(self) -> Any:
         try:
-            async for item in generator:
-                items.append(item)
+            # Run the generator's __anext__ in the preserved context
+            try:
+                # Python 3.11+ approach with context parameter
+                item = await asyncio.create_task(
+                    self.generator.__anext__(),  # type: ignore
+                    context=self.context,
+                )  # type: ignore
+            except TypeError:
+                # Python < 3.11 fallback - context parameter not supported
+                item = await self.generator.__anext__()
 
-            yield item
+            self.items.append(item)
 
-        finally:
-            output: Any = items
+            return item
 
-            if transform_to_string is not None:
-                output = transform_to_string(items)
+        except StopAsyncIteration:
+            # Handle output and span cleanup when generator is exhausted
+            output: Any = self.items
 
-            elif all(isinstance(item, str) for item in items):
-                output = "".join(items)
+            if self.transform_fn is not None:
+                output = self.transform_fn(self.items)
 
-            langfuse_span_or_generation.update(output=output)
-            langfuse_span_or_generation.end()
+            elif all(isinstance(item, str) for item in self.items):
+                output = "".join(self.items)
 
+            self.span.update(output=output).end()
 
-_decorator = LangfuseDecorator()
+            raise  # Re-raise StopAsyncIteration
 
+        except Exception as e:
+            self.span.update(level="ERROR", status_message=str(e)).end()
 
-observe = _decorator.observe
+            raise
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
index fe0a7f4c3..5803d531b 100644
--- a/tests/test_decorators.py
+++ b/tests/test_decorators.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import sys
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep
@@ -8,6 +9,7 @@
 import pytest
 from langchain.prompts import ChatPromptTemplate
 from langchain_openai import ChatOpenAI
+from opentelemetry import trace
 
 from langfuse import Langfuse, get_client, observe
 from langfuse._client.environment_variables import LANGFUSE_PUBLIC_KEY
@@ -1686,3 +1688,282 @@ async def async_root_function(*args, **kwargs):
 
     # Reset instances to not leak to other test suites
     removeMockResourceManagerInstances()
+
+
+def test_sync_generator_context_preservation():
+    """Test that sync generators preserve context when consumed later (e.g., by streaming responses)"""
+    langfuse = get_client()
+    mock_trace_id = langfuse.create_trace_id()
+
+    # Global variable to capture span information
+    span_info = {}
+
+    @observe(name="sync_generator")
+    def create_generator():
+        current_span = trace.get_current_span()
+        span_info["generator_span_id"] = trace.format_span_id(
+            current_span.get_span_context().span_id
+        )
+
+        for i in range(3):
+            yield f"item_{i}"
+
+    @observe(name="root")
+    def root_function():
+        current_span = trace.get_current_span()
+        span_info["root_span_id"] = trace.format_span_id(
+            current_span.get_span_context().span_id
+        )
+
+        # Return generator without consuming it (like FastAPI StreamingResponse would)
+        return create_generator()
+
+    # Simulate the scenario where generator is consumed after root function exits
+    generator = root_function(langfuse_trace_id=mock_trace_id)
+
+    # Consume generator later (like FastAPI would)
+    items = list(generator)
+
+    langfuse.flush()
+
+    # Verify results
+    assert items == ["item_0", "item_1", "item_2"]
+    assert (
+        span_info["generator_span_id"] != "0000000000000000"
+    ), "Generator context should be preserved"
+    assert (
+        span_info["root_span_id"] != span_info["generator_span_id"]
+    ), "Should have different span IDs"
+
+    # Verify trace structure
+    trace_data = get_api().trace.get(mock_trace_id)
+    assert len(trace_data.observations) == 2
+
+    # Verify both observations are present
+    observation_names = [obs.name for obs in trace_data.observations]
+    assert "root" in observation_names
+    assert "sync_generator" in observation_names
+
+    # Verify generator observation has output
+    generator_obs = next(
+        obs for obs in trace_data.observations if obs.name == "sync_generator"
+    )
+    assert generator_obs.output == "item_0item_1item_2"
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher")
+async def test_async_generator_context_preservation():
+    """Test that async generators preserve context when consumed later (e.g., by streaming responses)"""
+    langfuse = get_client()
+    mock_trace_id = langfuse.create_trace_id()
+
+    # Global variable to capture span information
+    span_info = {}
+
+    @observe(name="async_generator")
+    async def create_async_generator():
+        current_span = trace.get_current_span()
+        span_info["generator_span_id"] = trace.format_span_id(
+            current_span.get_span_context().span_id
+        )
+
+        for i in range(3):
+            await asyncio.sleep(0.001)  # Simulate async work
+            yield f"async_item_{i}"
+
+    @observe(name="root")
+    async def root_function():
+        current_span = trace.get_current_span()
+        span_info["root_span_id"] = trace.format_span_id(
+            current_span.get_span_context().span_id
+        )
+
+        # Return generator without consuming it (like FastAPI StreamingResponse would)
+        return create_async_generator()
+
+    # Simulate the scenario where generator is consumed after root function exits
+    generator = await root_function(langfuse_trace_id=mock_trace_id)
+
+    # Consume generator later (like FastAPI would)
+    items = []
+    async for item in generator:
+        items.append(item)
+
+    langfuse.flush()
+
+    # Verify results
+    assert items == ["async_item_0", "async_item_1", "async_item_2"]
+    assert (
+        span_info["generator_span_id"] != "0000000000000000"
+    ), "Generator context should be preserved"
+    assert (
+        span_info["root_span_id"] != span_info["generator_span_id"]
+    ), "Should have different span IDs"
+
+    # Verify trace structure
+    trace_data = get_api().trace.get(mock_trace_id)
+    assert len(trace_data.observations) == 2
+
+    # Verify both observations are present
+    observation_names = [obs.name for obs in trace_data.observations if obs.name]
+    assert "root" in observation_names
"async_generator" in observation_names + + # Verify generator observation has output + generator_obs = next( + obs for obs in trace_data.observations if obs.name == "async_generator" + ) + assert generator_obs.output == "async_item_0async_item_1async_item_2" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") +async def test_async_generator_context_preservation_with_trace_hierarchy(): + """Test that async generators maintain proper parent-child span relationships""" + langfuse = get_client() + mock_trace_id = langfuse.create_trace_id() + + # Global variables to capture span information + span_info = {} + + @observe(name="child_stream") + async def child_generator(): + current_span = trace.get_current_span() + span_context = current_span.get_span_context() + span_info["child_span_id"] = trace.format_span_id(span_context.span_id) + span_info["child_trace_id"] = trace.format_trace_id(span_context.trace_id) + + for i in range(2): + await asyncio.sleep(0.001) + yield f"child_{i}" + + @observe(name="parent_root") + async def parent_function(): + current_span = trace.get_current_span() + span_context = current_span.get_span_context() + span_info["parent_span_id"] = trace.format_span_id(span_context.span_id) + span_info["parent_trace_id"] = trace.format_trace_id(span_context.trace_id) + + # Create and return child generator + return child_generator() + + # Execute parent function + generator = await parent_function(langfuse_trace_id=mock_trace_id) + + # Consume generator (simulating delayed consumption) + items = [item async for item in generator] + + langfuse.flush() + + # Verify results + assert items == ["child_0", "child_1"] + + # Verify span hierarchy + assert ( + span_info["parent_span_id"] != span_info["child_span_id"] + ), "Parent and child should have different span IDs" + assert ( + span_info["parent_trace_id"] == span_info["child_trace_id"] + ), "Parent and child should share same trace ID" + assert ( + span_info["child_span_id"] != "0000000000000000" + ), "Child context should be preserved" + + # Verify trace structure + trace_data = get_api().trace.get(mock_trace_id) + assert len(trace_data.observations) == 2 + + # Check both observations exist + observation_names = [obs.name for obs in trace_data.observations if obs.name] + assert "parent_root" in observation_names + assert "child_stream" in observation_names + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") +async def test_async_generator_exception_handling_with_context(): + """Test that exceptions in async generators are properly handled while preserving context""" + langfuse = get_client() + mock_trace_id = langfuse.create_trace_id() + + @observe(name="failing_generator") + async def failing_generator(): + current_span = trace.get_current_span() + # Verify we have valid context even when exception occurs + assert ( + trace.format_span_id(current_span.get_span_context().span_id) + != "0000000000000000" + ) + + yield "first_item" + await asyncio.sleep(0.001) + raise ValueError("Generator failure test") + yield "never_reached" # This should never execute + + @observe(name="root") + async def root_function(): + return failing_generator() + + # Execute and consume generator + generator = await root_function(langfuse_trace_id=mock_trace_id) + + items = [] + with pytest.raises(ValueError, match="Generator failure test"): + async for item in generator: + items.append(item) + + langfuse.flush() + + # Verify 
+    # Verify partial results
+    assert items == ["first_item"]
+
+    # Verify trace structure - should have both observations despite exception
+    trace_data = get_api().trace.get(mock_trace_id)
+    assert len(trace_data.observations) == 2
+
+    # Check that the failing generator observation has ERROR level
+    failing_obs = next(
+        obs for obs in trace_data.observations if obs.name == "failing_generator"
+    )
+    assert failing_obs.level == "ERROR"
+    assert "Generator failure test" in failing_obs.status_message
+
+
+def test_sync_generator_empty_context_preservation():
+    """Test that empty sync generators work correctly with context preservation"""
+    langfuse = get_client()
+    mock_trace_id = langfuse.create_trace_id()
+
+    @observe(name="empty_generator")
+    def empty_generator():
+        current_span = trace.get_current_span()
+        # Should have valid context even for empty generator
+        assert (
+            trace.format_span_id(current_span.get_span_context().span_id)
+            != "0000000000000000"
+        )
+        return
+        yield  # Unreachable
+
+    @observe(name="root")
+    def root_function():
+        return empty_generator()
+
+    generator = root_function(langfuse_trace_id=mock_trace_id)
+    items = list(generator)
+
+    langfuse.flush()
+
+    # Verify results
+    assert items == []
+
+    # Verify trace structure
+    trace_data = get_api().trace.get(mock_trace_id)
+    assert len(trace_data.observations) == 2
+
+    # Verify empty generator observation
+    empty_obs = next(
+        obs for obs in trace_data.observations if obs.name == "empty_generator"
+    )
+    assert empty_obs.output is None
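
Illustrative usage sketch (not part of the patch): the delayed-consumption scenario the new wrappers target, where an @observe-decorated generator is returned un-consumed and only iterated later by a streaming response. It assumes FastAPI is installed alongside the Langfuse SDK; the endpoint, route, and token values are hypothetical.

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from langfuse import observe

app = FastAPI()


@observe(name="token_stream")
async def stream_tokens():
    # Each iteration runs in the context captured when the generator was wrapped,
    # so the observation stays attached to the right trace while streaming.
    for token in ["Hello", ", ", "world"]:
        yield token


@app.get("/stream")
@observe(name="stream_endpoint")
async def stream_endpoint():
    # The generator is returned without being consumed; StreamingResponse
    # iterates it only after this function has already returned.
    return StreamingResponse(stream_tokens(), media_type="text/plain")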