@@ -1751,26 +1751,60 @@ def root_function():
17511751 assert generator_obs .output == "item_0item_1item_2"
17521752
17531753
1754+ @pytest .fixture (params = ["generator" , "iterator" ])
1755+ def async_iterable_factory (request ):
1756+ """Factory that creates either an async generator or async iterator"""
1757+ iterable_type = request .param
1758+
1759+ if iterable_type == "generator" :
1760+
1761+ async def create_async_generator ():
1762+ for i in range (3 ):
1763+ await asyncio .sleep (0.001 )
1764+ yield f"async_item_{ i } "
1765+
1766+ return create_async_generator
1767+ else : # iterator
1768+
1769+ class AIter :
1770+ def __init__ (self ):
1771+ self .index = - 1
1772+
1773+ def __aiter__ (self ):
1774+ return self
1775+
1776+ async def __anext__ (self ):
1777+ if self .index < 2 :
1778+ await asyncio .sleep (0.001 )
1779+ self .index += 1
1780+ return f"async_item_{ self .index } "
1781+ else :
1782+ raise StopAsyncIteration
1783+
1784+ def create_async_iterator ():
1785+ return AIter ()
1786+
1787+ return create_async_iterator
1788+
1789+
17541790@pytest .mark .asyncio
17551791@pytest .mark .skipif (sys .version_info < (3 , 11 ), reason = "requires python3.11 or higher" )
1756- async def test_async_generator_context_preservation ():
1757- """Test that async generators preserve context when consumed later (e.g., by streaming responses)"""
1792+ async def test_async_generator_context_preservation (async_iterable_factory ):
1793+ """Test that async generators and iterators preserve context when consumed later (e.g., by streaming responses)"""
17581794 langfuse = get_client ()
17591795 mock_trace_id = langfuse .create_trace_id ()
17601796
17611797 # Global variable to capture span information
17621798 span_info = {}
17631799
17641800 @observe (name = "async_generator" )
1765- async def create_async_generator ():
1801+ async def create_async_iterable ():
17661802 current_span = trace .get_current_span ()
17671803 span_info ["generator_span_id" ] = trace .format_span_id (
17681804 current_span .get_span_context ().span_id
17691805 )
17701806
1771- for i in range (3 ):
1772- await asyncio .sleep (0.001 ) # Simulate async work
1773- yield f"async_item_{ i } "
1807+ return async_iterable_factory ()
17741808
17751809 @observe (name = "root" )
17761810 async def root_function ():
@@ -1779,23 +1813,23 @@ async def root_function():
17791813 current_span .get_span_context ().span_id
17801814 )
17811815
1782- # Return generator without consuming it (like FastAPI StreamingResponse would)
1783- return create_async_generator ()
1816+ # Return iterable without consuming it (like FastAPI StreamingResponse would)
1817+ return await create_async_iterable ()
17841818
1785- # Simulate the scenario where generator is consumed after root function exits
1786- generator = await root_function (langfuse_trace_id = mock_trace_id )
1819+ # Simulate the scenario where iterable is consumed after root function exits
1820+ iterable = await root_function (langfuse_trace_id = mock_trace_id )
17871821
1788- # Consume generator later (like FastAPI would)
1822+ # Consume iterable later (like FastAPI would)
17891823 items = []
1790- async for item in generator :
1824+ async for item in iterable :
17911825 items .append (item )
17921826
17931827 langfuse .flush ()
17941828
17951829 # Verify results
17961830 assert items == ["async_item_0" , "async_item_1" , "async_item_2" ]
17971831 assert span_info ["generator_span_id" ] != "0000000000000000" , (
1798- "Generator context should be preserved"
1832+ "Context should be preserved"
17991833 )
18001834 assert span_info ["root_span_id" ] != span_info ["generator_span_id" ], (
18011835 "Should have different span IDs"
@@ -1810,7 +1844,7 @@ async def root_function():
18101844 assert "root" in observation_names
18111845 assert "async_generator" in observation_names
18121846
1813- # Verify generator observation has output
1847+ # Verify observation has output
18141848 generator_obs = next (
18151849 obs for obs in trace_data .observations if obs .name == "async_generator"
18161850 )
0 commit comments