From a9b80d8d32b410096a1cd2e47ac438c31303131e Mon Sep 17 00:00:00 2001 From: Chris Mullins Date: Mon, 20 Oct 2025 15:56:43 -0700 Subject: [PATCH 1/4] Implement queue locking mechanism in StreamingResponse to prevent race conditions during activity processing --- .../app/streaming/streaming_response.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py index 4c1cff5b..95eba677 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py @@ -55,6 +55,7 @@ def __init__(self, context: "TurnContext"): # Queue for outgoing activities self._queue: List[Callable[[], Activity]] = [] self._queue_sync: Optional[asyncio.Task] = None + self._queue_lock = asyncio.Lock() self._chunk_queued = False # Powered by AI feature flags @@ -274,7 +275,6 @@ def _queue_next_chunk(self) -> None: self._chunk_queued = True def create_activity(): - self._chunk_queued = False if self._ended: # Send final message activity = Activity( @@ -303,8 +303,10 @@ def create_activity(): ], ) else: - return + self._chunk_queued = False + return None self._sequence_number += 1 + self._chunk_queued = False return activity self._queue_activity(create_activity) @@ -316,8 +318,14 @@ def _queue_activity(self, factory: Callable[[], Activity]) -> None: self._queue.append(factory) # If there's no sync in progress, start one - if not self._queue_sync: - self._queue_sync = asyncio.create_task(self._drain_queue()) + # Use a lock to prevent race conditions when checking/starting the drain task + async def start_drain_if_needed(): + async with self._queue_lock: + if not self._queue_sync or self._queue_sync.done(): + self._queue_sync = asyncio.create_task(self._drain_queue()) + + # Schedule the coroutine to run + asyncio.create_task(start_drain_if_needed()) async def _drain_queue(self) -> None: """ @@ -326,7 +334,12 @@ async def _drain_queue(self) -> None: try: logger.debug(f"Draining queue with {len(self._queue)} activities.") while self._queue: - factory = self._queue.pop(0) + # Use lock to safely access the queue + async with self._queue_lock: + if not self._queue: + break + factory = self._queue.pop(0) + activity = factory() if activity: await self._send_activity(activity) @@ -343,7 +356,8 @@ async def _drain_queue(self) -> None: ) raise finally: - self._queue_sync = None + async with self._queue_lock: + self._queue_sync = None async def _send_activity(self, activity: Activity) -> None: """ From 190bf66841fef75e7a26e40f56c0346c35494fb8 Mon Sep 17 00:00:00 2001 From: Chris Mullins Date: Mon, 20 Oct 2025 15:59:05 -0700 Subject: [PATCH 2/4] Fix formatting issue in _drain_queue method for improved readability --- .../hosting/aiohttp/app/streaming/streaming_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py index 95eba677..b636f88f 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py @@ -339,7 +339,7 @@ async def _drain_queue(self) -> None: if not self._queue: break factory = self._queue.pop(0) - + activity = factory() if activity: await self._send_activity(activity) From e8be1b7dd91fda9e40c11ae8265a4818740d2945 Mon Sep 17 00:00:00 2001 From: Chris Mullins Date: Tue, 21 Oct 2025 13:43:54 -0700 Subject: [PATCH 3/4] Refactor StreamingResponse to use asyncio.Queue for activity management and improve concurrency handling Note: I'm not convinced this is the right direction. I've converted the PR to draft, and we should re-evaluate. --- .../app/streaming/streaming_response.py | 368 +++++++---- test_samples/app_style/streaming_agent.py | 6 +- .../test_streaming_response.py | 572 ++++++++++++++++++ 3 files changed, 817 insertions(+), 129 deletions(-) create mode 100644 tests/hosting_aiohttp/test_streaming_response.py diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py index b636f88f..3441019c 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py @@ -3,6 +3,7 @@ import asyncio import logging +import threading from typing import List, Optional, Callable, Literal, TYPE_CHECKING from dataclasses import dataclass @@ -53,9 +54,14 @@ def __init__(self, context: "TurnContext"): self._cancelled = False # Queue for outgoing activities - self._queue: List[Callable[[], Activity]] = [] - self._queue_sync: Optional[asyncio.Task] = None - self._queue_lock = asyncio.Lock() + self._queue: asyncio.Queue[Callable[[], Activity]] = asyncio.Queue() + self._drain_task: Optional[asyncio.Task] = None + self._drain_task_lock = asyncio.Lock() + + # State lock to protect shared mutable state (for async operations) + self._state_lock = asyncio.Lock() + # Sync lock for message text updates (can be called from sync context) + self._message_lock = threading.Lock() self._chunk_queued = False # Powered by AI feature flags @@ -76,17 +82,24 @@ def stream_id(self) -> Optional[str]: """ Gets the stream ID of the current response. Assigned after the initial update is sent. + Note: Returns a snapshot; may be stale if called during concurrent updates. """ return self._stream_id @property def citations(self) -> Optional[List[ClientCitation]]: - """Gets the citations of the current response.""" + """ + Gets the citations of the current response. + Note: Returns reference to internal list; do not modify directly. + """ return self._citations @property def updates_sent(self) -> int: - """Gets the number of updates sent for the stream.""" + """ + Gets the number of updates sent for the stream. + Note: Returns a snapshot; may be stale if called during concurrent updates. + """ return self._sequence_number - 1 def queue_informative_update(self, text: str) -> None: @@ -102,8 +115,12 @@ def queue_informative_update(self, text: str) -> None: if self._ended: raise RuntimeError("The stream has already ended.") - # Queue a typing activity - def create_activity(): + # Queue a typing activity - capture sequence number atomically + async def create_activity_async(): + async with self._state_lock: + seq_num = self._sequence_number + self._sequence_number += 1 + activity = Activity( type="typing", text=text, @@ -111,14 +128,13 @@ def create_activity(): Entity( type="streaminfo", stream_type="informative", - stream_sequence=self._sequence_number, + stream_sequence=seq_num, ) ], ) - self._sequence_number += 1 return activity - self._queue_activity(create_activity) + self._queue_activity(create_activity_async) def queue_text_chunk( self, text: str, citations: Optional[List[Citation]] = None @@ -133,33 +149,38 @@ def queue_text_chunk( text: Partial text of the message to send. citations: Citations to be included in the message. """ - if self._cancelled: - return - if self._ended: - raise RuntimeError("The stream has already ended.") - - # Update full message text - self._message += text + # Update message text synchronously under thread lock + with self._message_lock: + if self._cancelled: + return + if self._ended: + raise RuntimeError("The stream has already ended.") - # If there are citations, modify the content so that the sources are numbers instead of [doc1], [doc2], etc. - self._message = CitationUtil.format_citations_response(self._message) + # Update full message text atomically + self._message += text + # If there are citations, modify the content so that the sources are numbers instead of [doc1], [doc2], etc. + self._message = CitationUtil.format_citations_response(self._message) - # Queue the next chunk - self._queue_next_chunk() + # Schedule the async queueing work in background + asyncio.create_task(self._queue_next_chunk()) async def end_stream(self) -> None: """ Ends the stream by sending the final message to the client. """ - if self._ended: - raise RuntimeError("The stream has already ended.") - - # Queue final message - self._ended = True - self._queue_next_chunk() + async with self._state_lock: + if self._ended: + raise RuntimeError("The stream has already ended.") + # Queue final message + self._ended = True + + await self._queue_next_chunk() # Wait for the queue to drain await self.wait_for_queue() + + # Clean up any remaining tasks + await self.cleanup() def set_attachments(self, attachments: List[Attachment]) -> None: """ @@ -185,25 +206,35 @@ def set_citations(self, citations: List[Citation]) -> None: Args: citations: Citations to be included in the message. + + Note: This method schedules the citation update atomically but does not block. """ - if citations: - if not self._citations: - self._citations = [] - - curr_pos = len(self._citations) - - for citation in citations: - client_citation = ClientCitation( - type="Claim", - position=curr_pos + 1, - appearance={ - "type": "DigitalDocument", - "name": citation.title or f"Document #{curr_pos + 1}", - "abstract": CitationUtil.snippet(citation.content, 477), - }, - ) - curr_pos += 1 - self._citations.append(client_citation) + if not citations: + return + + # Build new citations outside of lock + async def update_citations(): + async with self._state_lock: + if not self._citations: + self._citations = [] + + curr_pos = len(self._citations) + + for citation in citations: + client_citation = ClientCitation( + type="Claim", + position=curr_pos + 1, + appearance={ + "type": "DigitalDocument", + "name": citation.title or f"Document #{curr_pos + 1}", + "abstract": CitationUtil.snippet(citation.content, 477), + }, + ) + curr_pos += 1 + self._citations.append(client_citation) + + # Schedule the update to run in the event loop + asyncio.create_task(update_citations()) def set_feedback_loop(self, enable_feedback_loop: bool) -> None: """ @@ -241,14 +272,63 @@ def get_message(self) -> str: """ Returns the most recently streamed message. """ - return self._message + with self._message_lock: + return self._message async def wait_for_queue(self) -> None: """ Waits for the outgoing activity queue to be empty. """ - if self._queue_sync: - await self._queue_sync + await self._queue.join() + + async with self._drain_task_lock: + drain_task = self._drain_task + + if drain_task: + await drain_task + + async def cleanup(self) -> None: + """ + Cleans up resources and cancels any running tasks. + Should be called when the StreamingResponse is no longer needed. + """ + async with self._drain_task_lock: + drain_task = self._drain_task + self._drain_task = None + + if drain_task and not drain_task.done(): + drain_task.cancel() + try: + await drain_task + except asyncio.CancelledError: + logger.debug("Queue drain task was cancelled during cleanup.") + except Exception as err: + logger.warning(f"Error while cleaning up queue drain task: {err}") + + # Clear any remaining queue items to prevent memory leaks + while not self._queue.empty(): + try: + self._queue.get_nowait() + except asyncio.QueueEmpty: + break + else: + self._queue.task_done() + + async def cancel(self) -> None: + """ + Cancels the streaming response and cleans up resources. + """ + async with self._state_lock: + self._cancelled = True + await self.cleanup() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - ensures cleanup.""" + await self.cleanup() def _set_defaults(self, context: "TurnContext"): if context.activity.channel_id == Channels.ms_teams: @@ -263,51 +343,57 @@ def _set_defaults(self, context: "TurnContext"): self._channel_id = context.activity.channel_id - def _queue_next_chunk(self) -> None: + async def _queue_next_chunk(self) -> None: """ Queues the next chunk of text to be sent to the client. """ - # Are we already waiting to send a chunk? - if self._chunk_queued: - return - - # Queue a chunk of text to be sent - self._chunk_queued = True - - def create_activity(): - if self._ended: - # Send final message - activity = Activity( - type="message", - text=self._message or "end stream response", - attachments=self._attachments or [], - entities=[ - Entity( - type="streaminfo", - stream_type="final", - stream_sequence=self._sequence_number, - ) - ], - ) - elif self._is_streaming_channel: - # Send typing activity - activity = Activity( - type="typing", - text=self._message, - entities=[ - Entity( - type="streaminfo", - stream_type="streaming", - stream_sequence=self._sequence_number, - ) - ], - ) - else: + # Are we already waiting to send a chunk? (check atomically) + async with self._state_lock: + if self._chunk_queued: + return + self._chunk_queued = True + + # Create activity factory that captures current state + async def create_activity(): + # Capture message text under thread lock + with self._message_lock: + message = self._message + + async with self._state_lock: + if self._ended: + # Send final message + activity = Activity( + type="message", + text=message or "end stream response", + attachments=self._attachments or [], + entities=[ + Entity( + type="streaminfo", + stream_type="final", + stream_sequence=self._sequence_number, + ) + ], + ) + elif self._is_streaming_channel: + # Send typing activity + activity = Activity( + type="typing", + text=message, + entities=[ + Entity( + type="streaminfo", + stream_type="streaming", + stream_sequence=self._sequence_number, + ) + ], + ) + else: + self._chunk_queued = False + return None + + self._sequence_number += 1 self._chunk_queued = False - return None - self._sequence_number += 1 - self._chunk_queued = False - return activity + return activity self._queue_activity(create_activity) @@ -315,49 +401,67 @@ def _queue_activity(self, factory: Callable[[], Activity]) -> None: """ Queues an activity to be sent to the client. """ - self._queue.append(factory) + self._queue.put_nowait(factory) - # If there's no sync in progress, start one - # Use a lock to prevent race conditions when checking/starting the drain task - async def start_drain_if_needed(): - async with self._queue_lock: - if not self._queue_sync or self._queue_sync.done(): - self._queue_sync = asyncio.create_task(self._drain_queue()) + # Ensure a drain task is running to process the queue + async def ensure_drain_task(): + async with self._drain_task_lock: + if not self._drain_task or self._drain_task.done(): + try: + self._drain_task = asyncio.create_task(self._drain_queue()) + except Exception as err: + logger.error(f"Failed to create drain task: {err}") + self._drain_task = None + raise - # Schedule the coroutine to run - asyncio.create_task(start_drain_if_needed()) + asyncio.create_task(ensure_drain_task()) async def _drain_queue(self) -> None: """ Sends any queued activities to the client until the queue is empty. """ try: - logger.debug(f"Draining queue with {len(self._queue)} activities.") - while self._queue: - # Use lock to safely access the queue - async with self._queue_lock: - if not self._queue: + logger.debug("Draining queue with %s activities.", self._queue.qsize()) + while True: + # Check cancellation flag under lock + async with self._state_lock: + if self._cancelled: break - factory = self._queue.pop(0) + + try: + factory = self._queue.get_nowait() + except asyncio.QueueEmpty: + break - activity = factory() - if activity: - await self._send_activity(activity) + try: + activity = await factory() + # Check cancellation again before sending + async with self._state_lock: + cancelled = self._cancelled + if activity and not cancelled: + await self._send_activity(activity) + finally: + self._queue.task_done() + except asyncio.CancelledError: + logger.debug("Queue drain task was cancelled.") + raise except Exception as err: if ( "403" in str(err) and self._context.activity.channel_id == Channels.ms_teams ): logger.warning("Teams channel stopped the stream.") - self._cancelled = True + async with self._state_lock: + self._cancelled = True else: logger.error( f"Error occurred when sending activity while streaming: {err}" ) raise finally: - async with self._queue_lock: - self._queue_sync = None + # Always clean up the task reference + async with self._drain_task_lock: + self._drain_task = None async def _send_activity(self, activity: Activity) -> None: """ @@ -366,6 +470,19 @@ async def _send_activity(self, activity: Activity) -> None: Args: activity: The activity to send. """ + # Capture message under thread lock + with self._message_lock: + message = self._message + + # Capture other state snapshot under async lock + async with self._state_lock: + stream_id = self._stream_id + citations = self._citations + ended = self._ended + enable_feedback_loop = self._enable_feedback_loop + feedback_loop_type = self._feedback_loop_type + enable_generated_by_ai_label = self._enable_generated_by_ai_label + sensitivity_label = self._sensitivity_label streaminfo_entity = None @@ -384,15 +501,13 @@ async def _send_activity(self, activity: Activity) -> None: activity.entities.append(streaminfo_entity) # Set activity ID to the assigned stream ID - if self._stream_id: - activity.id = self._stream_id - streaminfo_entity.stream_id = self._stream_id + if stream_id: + activity.id = stream_id + streaminfo_entity.stream_id = stream_id - if self._citations and len(self._citations) > 0 and not self._ended: + if citations and len(citations) > 0 and not ended: # Filter out the citations unused in content. - curr_citations = CitationUtil.get_used_citations( - self._message, self._citations - ) + curr_citations = CitationUtil.get_used_citations(message, citations) if curr_citations: activity.entities.append( Entity( @@ -405,21 +520,22 @@ async def _send_activity(self, activity: Activity) -> None: ) # Add in Powered by AI feature flags - if self._ended: - if self._enable_feedback_loop and self._feedback_loop_type: + if ended: + if enable_feedback_loop and feedback_loop_type: # Add feedback loop to streaminfo entity - streaminfo_entity.feedback_loop = {"type": self._feedback_loop_type} + streaminfo_entity.feedback_loop = {"type": feedback_loop_type} else: # Add feedback loop enabled to streaminfo entity - streaminfo_entity.feedback_loop_enabled = self._enable_feedback_loop + streaminfo_entity.feedback_loop_enabled = enable_feedback_loop # Add in Generated by AI - if self._enable_generated_by_ai_label: - activity.add_ai_metadata(self._citations, self._sensitivity_label) + if enable_generated_by_ai_label: + activity.add_ai_metadata(citations, sensitivity_label) # Send activity response = await self._context.send_activity(activity) await asyncio.sleep(self._interval) - # Save assigned stream ID - if not self._stream_id and response: - self._stream_id = response.id + # Save assigned stream ID atomically + async with self._state_lock: + if not self._stream_id and response: + self._stream_id = response.id diff --git a/test_samples/app_style/streaming_agent.py b/test_samples/app_style/streaming_agent.py index e20f902d..df8d29de 100644 --- a/test_samples/app_style/streaming_agent.py +++ b/test_samples/app_style/streaming_agent.py @@ -76,13 +76,13 @@ async def on_message(context: TurnContext, state: TurnState): for i in range(5): print(f"Streaming chunk {i + 1}") - context.streaming_response.queue_text_chunk(f"part [{i + 1}] ") + await context.streaming_response.queue_text_chunk(f"part [{i + 1}] ") await asyncio.sleep(i * 0.5) - context.streaming_response.queue_text_chunk( + await context.streaming_response.queue_text_chunk( "This is the final message part. [doc1] and [doc2]" ) - context.streaming_response.set_citations( + await context.streaming_response.set_citations( [ Citation(title="Citation1", content="file", filepath="", url="file:////"), Citation( diff --git a/tests/hosting_aiohttp/test_streaming_response.py b/tests/hosting_aiohttp/test_streaming_response.py new file mode 100644 index 00000000..cc09ba87 --- /dev/null +++ b/tests/hosting_aiohttp/test_streaming_response.py @@ -0,0 +1,572 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import pytest +from unittest.mock import AsyncMock, Mock, MagicMock + +from microsoft_agents.activity import ( + Activity, + ActivityTypes, + ChannelAccount, + Channels, + ConversationAccount, + DeliveryModes, + Entity, + ResourceResponse, +) +from microsoft_agents.hosting.core import TurnContext, ChannelAdapter +from microsoft_agents.hosting.aiohttp.app.streaming import StreamingResponse, Citation + + +class MockAdapter(ChannelAdapter): + """Mock adapter for testing.""" + + def __init__(self): + self.sent_activities = [] + self._response_id = 1 + + async def send_activities(self, context, activities): + responses = [] + for activity in activities: + self.sent_activities.append(activity) + responses.append(ResourceResponse(id=f"response_{self._response_id}")) + self._response_id += 1 + return responses + + async def update_activity(self, context, activity): + return ResourceResponse(id=activity.id) + + async def delete_activity(self, context, reference): + pass + + +def create_test_activity( + channel_id: str = Channels.direct_line, + delivery_mode: str = DeliveryModes.normal, +) -> Activity: + """Create a test activity.""" + return Activity( + type=ActivityTypes.message, + id="test-activity-id", + text="test message", + from_property=ChannelAccount(id="user-id", name="User"), + recipient=ChannelAccount(id="bot-id", name="Bot"), + conversation=ConversationAccount(id="conversation-id"), + channel_id=channel_id, + service_url="https://test.example.com", + delivery_mode=delivery_mode, + ) + + +@pytest.fixture +def mock_adapter(): + """Create a mock adapter.""" + return MockAdapter() + + +@pytest.fixture +def turn_context(mock_adapter): + """Create a turn context for testing.""" + activity = create_test_activity() + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def streaming_context(mock_adapter): + """Create a turn context with streaming enabled (DirectLine).""" + activity = create_test_activity(channel_id=Channels.direct_line) + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def teams_context(mock_adapter): + """Create a turn context for Teams channel.""" + activity = create_test_activity(channel_id=Channels.ms_teams) + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def delivery_mode_context(mock_adapter): + """Create a turn context with delivery mode set to stream.""" + activity = create_test_activity( + channel_id="custom-channel", delivery_mode=DeliveryModes.stream + ) + return TurnContext(mock_adapter, activity) + + +class TestStreamingResponseInitialization: + """Test StreamingResponse initialization and configuration.""" + + def test_init_basic(self, turn_context): + """Test basic initialization.""" + response = StreamingResponse(turn_context) + assert response._context == turn_context + assert response._sequence_number == 1 + assert response._stream_id is None + assert response._message == "" + assert response._ended is False + assert response._cancelled is False + + def test_init_direct_line_sets_streaming_channel(self, streaming_context): + """Test DirectLine channel enables streaming.""" + response = StreamingResponse(streaming_context) + assert response._is_streaming_channel is True + assert response._interval == 0.5 + + def test_init_teams_sets_streaming_channel(self, teams_context): + """Test Teams channel enables streaming.""" + response = StreamingResponse(teams_context) + assert response._is_streaming_channel is True + assert response._interval == 1.0 + + def test_init_delivery_mode_stream(self, delivery_mode_context): + """Test delivery_mode='stream' enables streaming.""" + response = StreamingResponse(delivery_mode_context) + assert response._is_streaming_channel is True + assert response._interval == 0.1 + + +class TestStreamingResponseProperties: + """Test StreamingResponse properties.""" + + def test_stream_id_property_initial(self, turn_context): + """Test stream_id property returns None initially.""" + response = StreamingResponse(turn_context) + assert response.stream_id is None + + def test_citations_property_initial(self, turn_context): + """Test citations property returns empty list initially.""" + response = StreamingResponse(turn_context) + assert response.citations == [] + + def test_updates_sent_property_initial(self, turn_context): + """Test updates_sent property returns 0 initially.""" + response = StreamingResponse(turn_context) + assert response.updates_sent == 0 + + +class TestQueueInformativeUpdate: + """Test queue_informative_update method.""" + + @pytest.mark.asyncio + async def test_queue_informative_update_non_streaming_channel(self, turn_context): + """Test informative update is not queued for non-streaming channels.""" + # Create context with non-streaming channel + activity = create_test_activity(channel_id="non-streaming") + context = TurnContext(MockAdapter(), activity) + response = StreamingResponse(context) + + response.queue_informative_update("test message") + await asyncio.sleep(0.1) # Allow queue to process + + # Should not queue anything for non-streaming channels + assert context.adapter.sent_activities == [] + + @pytest.mark.asyncio + async def test_queue_informative_update_streaming_channel(self, streaming_context): + """Test informative update is queued for streaming channels.""" + response = StreamingResponse(streaming_context) + + response.queue_informative_update("Starting process...") + await response.wait_for_queue() + + assert len(streaming_context.adapter.sent_activities) == 1 + activity = streaming_context.adapter.sent_activities[0] + assert activity.type == "typing" + assert activity.text == "Starting process..." + assert len(activity.entities) == 1 + assert activity.entities[0].type == "streaminfo" + assert activity.entities[0].stream_type == "informative" + assert activity.entities[0].stream_sequence == 1 + + @pytest.mark.asyncio + async def test_queue_informative_update_raises_after_ended(self, streaming_context): + """Test informative update raises error after stream ended.""" + response = StreamingResponse(streaming_context) + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + response.queue_informative_update("test") + + +class TestQueueTextChunk: + """Test queue_text_chunk method.""" + + @pytest.mark.asyncio + async def test_queue_text_chunk_simple(self, streaming_context): + """Test queueing a simple text chunk.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("Hello ") + await response.wait_for_queue() + + assert len(streaming_context.adapter.sent_activities) == 1 + activity = streaming_context.adapter.sent_activities[0] + assert activity.type == "typing" + assert activity.text == "Hello " + + @pytest.mark.asyncio + async def test_queue_text_chunk_multiple(self, streaming_context): + """Test queueing multiple text chunks.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("Hello ") + await response.queue_text_chunk("World") + await response.queue_text_chunk("!") + await response.wait_for_queue() + + # Message should accumulate + assert response.get_message() == "Hello World!" + assert len(streaming_context.adapter.sent_activities) >= 1 + + @pytest.mark.asyncio + async def test_queue_text_chunk_cancelled(self, streaming_context): + """Test queueing text chunk after cancellation.""" + response = StreamingResponse(streaming_context) + await response.cancel() + + # Should not raise, just return silently + await response.queue_text_chunk("test") + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_queue_text_chunk_raises_after_ended(self, streaming_context): + """Test queueing text chunk after stream ended raises error.""" + response = StreamingResponse(streaming_context) + await response.queue_text_chunk("test") + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + await response.queue_text_chunk("more") + + @pytest.mark.asyncio + async def test_queue_text_chunk_updates_sequence(self, streaming_context): + """Test that text chunks increment sequence number.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("chunk1") + await response.wait_for_queue() # Wait for first chunk to be sent + + await response.queue_text_chunk("chunk2") + await response.wait_for_queue() # Wait for second chunk to be sent + + # Updates_sent should reflect the number of activities sent + # Each chunk should result in at least one activity + assert response.updates_sent >= 2 + + +class TestEndStream: + """Test end_stream method.""" + + @pytest.mark.asyncio + async def test_end_stream_sends_final_message(self, streaming_context): + """Test end_stream sends a final message activity.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("Hello World") + await response.end_stream() + + # Find the final message + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + activity = final_activities[0] + assert activity.text == "Hello World" + assert any( + e.type == "streaminfo" and e.stream_type == "final" + for e in activity.entities + ) + + @pytest.mark.asyncio + async def test_end_stream_twice_raises_error(self, streaming_context): + """Test calling end_stream twice raises error.""" + response = StreamingResponse(streaming_context) + + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + await response.end_stream() + + @pytest.mark.asyncio + async def test_end_stream_with_empty_message(self, streaming_context): + """Test end_stream with no text queued.""" + response = StreamingResponse(streaming_context) + await response.end_stream() + + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + assert final_activities[0].text == "end stream response" + + +class TestSetAttachments: + """Test set_attachments method.""" + + @pytest.mark.asyncio + async def test_set_attachments(self, streaming_context): + """Test setting attachments on final message.""" + from microsoft_agents.activity import Attachment + + response = StreamingResponse(streaming_context) + + attachments = [ + Attachment(content_type="text/plain", content="test attachment") + ] + response.set_attachments(attachments) + await response.end_stream() + + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + assert final_activities[0].attachments == attachments + + +class TestSetCitations: + """Test set_citations method.""" + + @pytest.mark.asyncio + async def test_set_citations(self, streaming_context): + """Test setting citations.""" + response = StreamingResponse(streaming_context) + + citations = [ + Citation( + title="Source 1", + content="Content 1", + filepath="file1.txt", + url="http://example.com/1", + ), + Citation( + title="Source 2", + content="Content 2", + filepath="file2.txt", + url="http://example.com/2", + ), + ] + + await response.set_citations(citations) + + assert response.citations is not None + assert len(response.citations) == 2 + assert response.citations[0].position == 1 + assert response.citations[1].position == 2 + + @pytest.mark.asyncio + async def test_set_citations_multiple_calls(self, streaming_context): + """Test setting citations multiple times appends.""" + response = StreamingResponse(streaming_context) + + citations1 = [ + Citation( + title="Source 1", + content="Content 1", + filepath="file1.txt", + url="http://example.com/1", + ) + ] + await response.set_citations(citations1) + + citations2 = [ + Citation( + title="Source 2", + content="Content 2", + filepath="file2.txt", + url="http://example.com/2", + ) + ] + await response.set_citations(citations2) + + assert len(response.citations) == 2 + + +class TestFeedbackAndLabels: + """Test feedback loop and AI label settings.""" + + def test_set_feedback_loop(self, turn_context): + """Test setting feedback loop.""" + response = StreamingResponse(turn_context) + response.set_feedback_loop(True) + assert response._enable_feedback_loop is True + + def test_set_feedback_loop_type(self, turn_context): + """Test setting feedback loop type.""" + response = StreamingResponse(turn_context) + response.set_feedback_loop_type("custom") + assert response._feedback_loop_type == "custom" + + def test_set_generated_by_ai_label(self, turn_context): + """Test setting generated by AI label.""" + response = StreamingResponse(turn_context) + response.set_generated_by_ai_label(True) + assert response._enable_generated_by_ai_label is True + + +class TestGetMessage: + """Test get_message method.""" + + @pytest.mark.asyncio + async def test_get_message_initial(self, turn_context): + """Test get_message returns empty string initially.""" + response = StreamingResponse(turn_context) + assert response.get_message() == "" + + @pytest.mark.asyncio + async def test_get_message_after_chunks(self, streaming_context): + """Test get_message returns accumulated text.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("Hello ") + await response.queue_text_chunk("World") + + assert response.get_message() == "Hello World" + + +class TestWaitForQueue: + """Test wait_for_queue method.""" + + @pytest.mark.asyncio + async def test_wait_for_queue_empty(self, turn_context): + """Test wait_for_queue returns immediately when queue is empty.""" + response = StreamingResponse(turn_context) + await response.wait_for_queue() # Should not hang + + @pytest.mark.asyncio + async def test_wait_for_queue_processes_all(self, streaming_context): + """Test wait_for_queue waits for all items to process.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("chunk1") + await response.wait_for_queue() # Wait for chunk to be sent + + await response.queue_text_chunk("chunk2") + await response.wait_for_queue() # Wait for chunk to be sent + + await response.queue_text_chunk("chunk3") + await response.wait_for_queue() # Wait for chunk to be sent + await response.wait_for_queue() + + # All activities should be sent + assert len(streaming_context.adapter.sent_activities) >= 3 + + +class TestCleanupAndCancel: + """Test cleanup and cancel methods.""" + + @pytest.mark.asyncio + async def test_cleanup(self, streaming_context): + """Test cleanup method.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("test") + await response.cleanup() + + # Queue should be empty after cleanup + assert response._queue.empty() + + @pytest.mark.asyncio + async def test_cancel(self, streaming_context): + """Test cancel method.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("test") + await response.cancel() + + # Cancelled flag should be set + assert response._cancelled is True + + @pytest.mark.asyncio + async def test_cancel_stops_processing(self, streaming_context): + """Test cancel stops further processing.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("chunk1") + await response.cancel() + + # Subsequent chunks should not raise, just be ignored + await response.queue_text_chunk("chunk2") + await asyncio.sleep(0.1) + + +class TestContextManager: + """Test async context manager protocol.""" + + @pytest.mark.asyncio + async def test_context_manager_cleanup(self, streaming_context): + """Test context manager calls cleanup on exit.""" + async with StreamingResponse(streaming_context) as response: + await response.queue_text_chunk("test") + + # After exiting context, cleanup should have been called + assert response._queue.empty() + + @pytest.mark.asyncio + async def test_context_manager_with_exception(self, streaming_context): + """Test context manager cleans up even with exception.""" + try: + async with StreamingResponse(streaming_context) as response: + await response.queue_text_chunk("test") + raise ValueError("Test exception") + except ValueError: + pass + + # Cleanup should still have been called + assert response._queue.empty() + + +class TestStreamIdAssignment: + """Test stream ID assignment.""" + + @pytest.mark.asyncio + async def test_stream_id_assigned_after_first_send(self, streaming_context): + """Test stream_id is assigned after first activity is sent.""" + response = StreamingResponse(streaming_context) + + assert response.stream_id is None + + await response.queue_text_chunk("test") + await response.wait_for_queue() + + # Stream ID should be assigned after sending + assert response.stream_id is not None + assert response.stream_id.startswith("response_") + + @pytest.mark.asyncio + async def test_stream_id_consistent_across_activities(self, streaming_context): + """Test all activities in stream use same stream ID.""" + response = StreamingResponse(streaming_context) + + await response.queue_text_chunk("chunk1") + await response.wait_for_queue() + + first_stream_id = response.stream_id + + await response.queue_text_chunk("chunk2") + await response.wait_for_queue() + + assert response.stream_id == first_stream_id + + +class TestSequenceNumbers: + """Test sequence number management.""" + + @pytest.mark.asyncio + async def test_sequence_numbers_increment(self, streaming_context): + """Test sequence numbers increment correctly.""" + response = StreamingResponse(streaming_context) + + response.queue_informative_update("info1") + await response.queue_text_chunk("chunk1") + await response.queue_text_chunk("chunk2") + await response.wait_for_queue() + + # Each activity should have increasing sequence numbers + for i, activity in enumerate(streaming_context.adapter.sent_activities): + stream_entity = next( + (e for e in activity.entities if e.type == "streaminfo"), None + ) + assert stream_entity is not None + assert stream_entity.stream_sequence == i + 1 From fcfa96d05f3c0f8a64043c584fa60eaa89977ff6 Mon Sep 17 00:00:00 2001 From: Chris Mullins Date: Tue, 21 Oct 2025 13:44:04 -0700 Subject: [PATCH 4/4] Refactor StreamingResponse tests to remove unnecessary await on queue_text_chunk and set_citations methods for improved performance --- test_samples/app_style/streaming_agent.py | 6 +- .../test_streaming_response.py | 80 ++++++++++++------- 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/test_samples/app_style/streaming_agent.py b/test_samples/app_style/streaming_agent.py index df8d29de..e20f902d 100644 --- a/test_samples/app_style/streaming_agent.py +++ b/test_samples/app_style/streaming_agent.py @@ -76,13 +76,13 @@ async def on_message(context: TurnContext, state: TurnState): for i in range(5): print(f"Streaming chunk {i + 1}") - await context.streaming_response.queue_text_chunk(f"part [{i + 1}] ") + context.streaming_response.queue_text_chunk(f"part [{i + 1}] ") await asyncio.sleep(i * 0.5) - await context.streaming_response.queue_text_chunk( + context.streaming_response.queue_text_chunk( "This is the final message part. [doc1] and [doc2]" ) - await context.streaming_response.set_citations( + context.streaming_response.set_citations( [ Citation(title="Citation1", content="file", filepath="", url="file:////"), Citation( diff --git a/tests/hosting_aiohttp/test_streaming_response.py b/tests/hosting_aiohttp/test_streaming_response.py index cc09ba87..e423d173 100644 --- a/tests/hosting_aiohttp/test_streaming_response.py +++ b/tests/hosting_aiohttp/test_streaming_response.py @@ -198,7 +198,8 @@ async def test_queue_text_chunk_simple(self, streaming_context): """Test queueing a simple text chunk.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("Hello ") + response.queue_text_chunk("Hello ") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() assert len(streaming_context.adapter.sent_activities) == 1 @@ -211,9 +212,10 @@ async def test_queue_text_chunk_multiple(self, streaming_context): """Test queueing multiple text chunks.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("Hello ") - await response.queue_text_chunk("World") - await response.queue_text_chunk("!") + response.queue_text_chunk("Hello ") + response.queue_text_chunk("World") + response.queue_text_chunk("!") + await asyncio.sleep(0.01) # Allow background tasks to start await response.wait_for_queue() # Message should accumulate @@ -227,28 +229,31 @@ async def test_queue_text_chunk_cancelled(self, streaming_context): await response.cancel() # Should not raise, just return silently - await response.queue_text_chunk("test") + response.queue_text_chunk("test") await asyncio.sleep(0.1) @pytest.mark.asyncio async def test_queue_text_chunk_raises_after_ended(self, streaming_context): """Test queueing text chunk after stream ended raises error.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start await response.end_stream() with pytest.raises(RuntimeError, match="stream has already ended"): - await response.queue_text_chunk("more") + response.queue_text_chunk("more") @pytest.mark.asyncio async def test_queue_text_chunk_updates_sequence(self, streaming_context): """Test that text chunks increment sequence number.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Wait for first chunk to be sent - await response.queue_text_chunk("chunk2") + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Wait for second chunk to be sent # Updates_sent should reflect the number of activities sent @@ -264,7 +269,8 @@ async def test_end_stream_sends_final_message(self, streaming_context): """Test end_stream sends a final message activity.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("Hello World") + response.queue_text_chunk("Hello World") + await asyncio.sleep(0.01) # Allow background task to start await response.end_stream() # Find the final message @@ -348,7 +354,8 @@ async def test_set_citations(self, streaming_context): ), ] - await response.set_citations(citations) + response.set_citations(citations) + await asyncio.sleep(0.05) # Allow async task to complete assert response.citations is not None assert len(response.citations) == 2 @@ -368,7 +375,8 @@ async def test_set_citations_multiple_calls(self, streaming_context): url="http://example.com/1", ) ] - await response.set_citations(citations1) + response.set_citations(citations1) + await asyncio.sleep(0.05) # Allow async task to complete citations2 = [ Citation( @@ -378,7 +386,8 @@ async def test_set_citations_multiple_calls(self, streaming_context): url="http://example.com/2", ) ] - await response.set_citations(citations2) + response.set_citations(citations2) + await asyncio.sleep(0.05) # Allow async task to complete assert len(response.citations) == 2 @@ -419,9 +428,10 @@ async def test_get_message_after_chunks(self, streaming_context): """Test get_message returns accumulated text.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("Hello ") - await response.queue_text_chunk("World") + response.queue_text_chunk("Hello ") + response.queue_text_chunk("World") + # Message is updated synchronously, no need to wait assert response.get_message() == "Hello World" @@ -439,13 +449,16 @@ async def test_wait_for_queue_processes_all(self, streaming_context): """Test wait_for_queue waits for all items to process.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Wait for chunk to be sent - await response.queue_text_chunk("chunk2") + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Wait for chunk to be sent - await response.queue_text_chunk("chunk3") + response.queue_text_chunk("chunk3") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Wait for chunk to be sent await response.wait_for_queue() @@ -461,7 +474,8 @@ async def test_cleanup(self, streaming_context): """Test cleanup method.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start await response.cleanup() # Queue should be empty after cleanup @@ -472,7 +486,8 @@ async def test_cancel(self, streaming_context): """Test cancel method.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start await response.cancel() # Cancelled flag should be set @@ -483,11 +498,12 @@ async def test_cancel_stops_processing(self, streaming_context): """Test cancel stops further processing.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start await response.cancel() # Subsequent chunks should not raise, just be ignored - await response.queue_text_chunk("chunk2") + response.queue_text_chunk("chunk2") await asyncio.sleep(0.1) @@ -498,7 +514,8 @@ class TestContextManager: async def test_context_manager_cleanup(self, streaming_context): """Test context manager calls cleanup on exit.""" async with StreamingResponse(streaming_context) as response: - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start # After exiting context, cleanup should have been called assert response._queue.empty() @@ -508,7 +525,8 @@ async def test_context_manager_with_exception(self, streaming_context): """Test context manager cleans up even with exception.""" try: async with StreamingResponse(streaming_context) as response: - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start raise ValueError("Test exception") except ValueError: pass @@ -527,7 +545,8 @@ async def test_stream_id_assigned_after_first_send(self, streaming_context): assert response.stream_id is None - await response.queue_text_chunk("test") + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() # Stream ID should be assigned after sending @@ -539,12 +558,14 @@ async def test_stream_id_consistent_across_activities(self, streaming_context): """Test all activities in stream use same stream ID.""" response = StreamingResponse(streaming_context) - await response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() first_stream_id = response.stream_id - await response.queue_text_chunk("chunk2") + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start await response.wait_for_queue() assert response.stream_id == first_stream_id @@ -559,8 +580,9 @@ async def test_sequence_numbers_increment(self, streaming_context): response = StreamingResponse(streaming_context) response.queue_informative_update("info1") - await response.queue_text_chunk("chunk1") - await response.queue_text_chunk("chunk2") + response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background tasks to start await response.wait_for_queue() # Each activity should have increasing sequence numbers