Skip to content

Commit 51dfe95

Browse files
committed
add test
1 parent 7630423 commit 51dfe95

File tree

2 files changed

+57
-28
lines changed

2 files changed

+57
-28
lines changed

langfuse/_client/observe.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,12 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
308308
):
309309
is_return_type_iterator = True
310310

311-
result.body_iterator = (
312-
self._wrap_async_iterator_result(
313-
langfuse_span_or_generation,
314-
result.body_iterator,
315-
transform_to_string,
316-
)
311+
result.body_iterator = self._wrap_async_iterator_result(
312+
langfuse_span_or_generation,
313+
result.body_iterator,
314+
transform_to_string,
317315
)
318-
316+
319317
if isinstance(result, AsyncIterator):
320318
is_return_type_iterator = True
321319

@@ -435,14 +433,11 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
435433
):
436434
is_return_type_iterator = True
437435

438-
result.body_iterator = (
439-
self._wrap_async_iterator_result(
440-
langfuse_span_or_generation,
441-
result.body_iterator,
442-
transform_to_string,
443-
)
436+
result.body_iterator = self._wrap_async_iterator_result(
437+
langfuse_span_or_generation,
438+
result.body_iterator,
439+
transform_to_string,
444440
)
445-
446441

447442
langfuse_span_or_generation.update(output=result)
448443

tests/test_decorators.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,26 +1751,60 @@ def root_function():
17511751
assert generator_obs.output == "item_0item_1item_2"
17521752

17531753

1754+
@pytest.fixture(params=["generator", "iterator"])
1755+
def async_iterable_factory(request):
1756+
"""Factory that creates either an async generator or async iterator"""
1757+
iterable_type = request.param
1758+
1759+
if iterable_type == "generator":
1760+
1761+
async def create_async_generator():
1762+
for i in range(3):
1763+
await asyncio.sleep(0.001)
1764+
yield f"async_item_{i}"
1765+
1766+
return create_async_generator
1767+
else: # iterator
1768+
1769+
class AIter:
1770+
def __init__(self):
1771+
self.index = -1
1772+
1773+
def __aiter__(self):
1774+
return self
1775+
1776+
async def __anext__(self):
1777+
if self.index < 2:
1778+
await asyncio.sleep(0.001)
1779+
self.index += 1
1780+
return f"async_item_{self.index}"
1781+
else:
1782+
raise StopAsyncIteration
1783+
1784+
def create_async_iterator():
1785+
return AIter()
1786+
1787+
return create_async_iterator
1788+
1789+
17541790
@pytest.mark.asyncio
17551791
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher")
1756-
async def test_async_generator_context_preservation():
1757-
"""Test that async generators preserve context when consumed later (e.g., by streaming responses)"""
1792+
async def test_async_generator_context_preservation(async_iterable_factory):
1793+
"""Test that async generators and iterators preserve context when consumed later (e.g., by streaming responses)"""
17581794
langfuse = get_client()
17591795
mock_trace_id = langfuse.create_trace_id()
17601796

17611797
# Global variable to capture span information
17621798
span_info = {}
17631799

17641800
@observe(name="async_generator")
1765-
async def create_async_generator():
1801+
async def create_async_iterable():
17661802
current_span = trace.get_current_span()
17671803
span_info["generator_span_id"] = trace.format_span_id(
17681804
current_span.get_span_context().span_id
17691805
)
17701806

1771-
for i in range(3):
1772-
await asyncio.sleep(0.001) # Simulate async work
1773-
yield f"async_item_{i}"
1807+
return async_iterable_factory()
17741808

17751809
@observe(name="root")
17761810
async def root_function():
@@ -1779,23 +1813,23 @@ async def root_function():
17791813
current_span.get_span_context().span_id
17801814
)
17811815

1782-
# Return generator without consuming it (like FastAPI StreamingResponse would)
1783-
return create_async_generator()
1816+
# Return iterable without consuming it (like FastAPI StreamingResponse would)
1817+
return await create_async_iterable()
17841818

1785-
# Simulate the scenario where generator is consumed after root function exits
1786-
generator = await root_function(langfuse_trace_id=mock_trace_id)
1819+
# Simulate the scenario where iterable is consumed after root function exits
1820+
iterable = await root_function(langfuse_trace_id=mock_trace_id)
17871821

1788-
# Consume generator later (like FastAPI would)
1822+
# Consume iterable later (like FastAPI would)
17891823
items = []
1790-
async for item in generator:
1824+
async for item in iterable:
17911825
items.append(item)
17921826

17931827
langfuse.flush()
17941828

17951829
# Verify results
17961830
assert items == ["async_item_0", "async_item_1", "async_item_2"]
17971831
assert span_info["generator_span_id"] != "0000000000000000", (
1798-
"Generator context should be preserved"
1832+
"Context should be preserved"
17991833
)
18001834
assert span_info["root_span_id"] != span_info["generator_span_id"], (
18011835
"Should have different span IDs"
@@ -1810,7 +1844,7 @@ async def root_function():
18101844
assert "root" in observation_names
18111845
assert "async_generator" in observation_names
18121846

1813-
# Verify generator observation has output
1847+
# Verify observation has output
18141848
generator_obs = next(
18151849
obs for obs in trace_data.observations if obs.name == "async_generator"
18161850
)

0 commit comments

Comments
 (0)