diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py index 4c1cff5b..3441019c 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py @@ -3,6 +3,7 @@ import asyncio import logging +import threading from typing import List, Optional, Callable, Literal, TYPE_CHECKING from dataclasses import dataclass @@ -53,8 +54,14 @@ def __init__(self, context: "TurnContext"): self._cancelled = False # Queue for outgoing activities - self._queue: List[Callable[[], Activity]] = [] - self._queue_sync: Optional[asyncio.Task] = None + self._queue: asyncio.Queue[Callable[[], Activity]] = asyncio.Queue() + self._drain_task: Optional[asyncio.Task] = None + self._drain_task_lock = asyncio.Lock() + + # State lock to protect shared mutable state (for async operations) + self._state_lock = asyncio.Lock() + # Sync lock for message text updates (can be called from sync context) + self._message_lock = threading.Lock() self._chunk_queued = False # Powered by AI feature flags @@ -75,17 +82,24 @@ def stream_id(self) -> Optional[str]: """ Gets the stream ID of the current response. Assigned after the initial update is sent. + Note: Returns a snapshot; may be stale if called during concurrent updates. """ return self._stream_id @property def citations(self) -> Optional[List[ClientCitation]]: - """Gets the citations of the current response.""" + """ + Gets the citations of the current response. + Note: Returns reference to internal list; do not modify directly. + """ return self._citations @property def updates_sent(self) -> int: - """Gets the number of updates sent for the stream.""" + """ + Gets the number of updates sent for the stream. + Note: Returns a snapshot; may be stale if called during concurrent updates. + """ return self._sequence_number - 1 def queue_informative_update(self, text: str) -> None: @@ -101,8 +115,12 @@ def queue_informative_update(self, text: str) -> None: if self._ended: raise RuntimeError("The stream has already ended.") - # Queue a typing activity - def create_activity(): + # Queue a typing activity - capture sequence number atomically + async def create_activity_async(): + async with self._state_lock: + seq_num = self._sequence_number + self._sequence_number += 1 + activity = Activity( type="typing", text=text, @@ -110,14 +128,13 @@ def create_activity(): Entity( type="streaminfo", stream_type="informative", - stream_sequence=self._sequence_number, + stream_sequence=seq_num, ) ], ) - self._sequence_number += 1 return activity - self._queue_activity(create_activity) + self._queue_activity(create_activity_async) def queue_text_chunk( self, text: str, citations: Optional[List[Citation]] = None @@ -132,33 +149,38 @@ def queue_text_chunk( text: Partial text of the message to send. citations: Citations to be included in the message. """ - if self._cancelled: - return - if self._ended: - raise RuntimeError("The stream has already ended.") - - # Update full message text - self._message += text + # Update message text synchronously under thread lock + with self._message_lock: + if self._cancelled: + return + if self._ended: + raise RuntimeError("The stream has already ended.") - # If there are citations, modify the content so that the sources are numbers instead of [doc1], [doc2], etc. - self._message = CitationUtil.format_citations_response(self._message) + # Update full message text atomically + self._message += text + # If there are citations, modify the content so that the sources are numbers instead of [doc1], [doc2], etc. + self._message = CitationUtil.format_citations_response(self._message) - # Queue the next chunk - self._queue_next_chunk() + # Schedule the async queueing work in background + asyncio.create_task(self._queue_next_chunk()) async def end_stream(self) -> None: """ Ends the stream by sending the final message to the client. """ - if self._ended: - raise RuntimeError("The stream has already ended.") - - # Queue final message - self._ended = True - self._queue_next_chunk() + async with self._state_lock: + if self._ended: + raise RuntimeError("The stream has already ended.") + # Queue final message + self._ended = True + + await self._queue_next_chunk() # Wait for the queue to drain await self.wait_for_queue() + + # Clean up any remaining tasks + await self.cleanup() def set_attachments(self, attachments: List[Attachment]) -> None: """ @@ -184,25 +206,35 @@ def set_citations(self, citations: List[Citation]) -> None: Args: citations: Citations to be included in the message. + + Note: This method schedules the citation update atomically but does not block. """ - if citations: - if not self._citations: - self._citations = [] - - curr_pos = len(self._citations) - - for citation in citations: - client_citation = ClientCitation( - type="Claim", - position=curr_pos + 1, - appearance={ - "type": "DigitalDocument", - "name": citation.title or f"Document #{curr_pos + 1}", - "abstract": CitationUtil.snippet(citation.content, 477), - }, - ) - curr_pos += 1 - self._citations.append(client_citation) + if not citations: + return + + # Build new citations outside of lock + async def update_citations(): + async with self._state_lock: + if not self._citations: + self._citations = [] + + curr_pos = len(self._citations) + + for citation in citations: + client_citation = ClientCitation( + type="Claim", + position=curr_pos + 1, + appearance={ + "type": "DigitalDocument", + "name": citation.title or f"Document #{curr_pos + 1}", + "abstract": CitationUtil.snippet(citation.content, 477), + }, + ) + curr_pos += 1 + self._citations.append(client_citation) + + # Schedule the update to run in the event loop + asyncio.create_task(update_citations()) def set_feedback_loop(self, enable_feedback_loop: bool) -> None: """ @@ -240,14 +272,63 @@ def get_message(self) -> str: """ Returns the most recently streamed message. """ - return self._message + with self._message_lock: + return self._message async def wait_for_queue(self) -> None: """ Waits for the outgoing activity queue to be empty. """ - if self._queue_sync: - await self._queue_sync + await self._queue.join() + + async with self._drain_task_lock: + drain_task = self._drain_task + + if drain_task: + await drain_task + + async def cleanup(self) -> None: + """ + Cleans up resources and cancels any running tasks. + Should be called when the StreamingResponse is no longer needed. + """ + async with self._drain_task_lock: + drain_task = self._drain_task + self._drain_task = None + + if drain_task and not drain_task.done(): + drain_task.cancel() + try: + await drain_task + except asyncio.CancelledError: + logger.debug("Queue drain task was cancelled during cleanup.") + except Exception as err: + logger.warning(f"Error while cleaning up queue drain task: {err}") + + # Clear any remaining queue items to prevent memory leaks + while not self._queue.empty(): + try: + self._queue.get_nowait() + except asyncio.QueueEmpty: + break + else: + self._queue.task_done() + + async def cancel(self) -> None: + """ + Cancels the streaming response and cleans up resources. + """ + async with self._state_lock: + self._cancelled = True + await self.cleanup() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - ensures cleanup.""" + await self.cleanup() def _set_defaults(self, context: "TurnContext"): if context.activity.channel_id == Channels.ms_teams: @@ -262,50 +343,57 @@ def _set_defaults(self, context: "TurnContext"): self._channel_id = context.activity.channel_id - def _queue_next_chunk(self) -> None: + async def _queue_next_chunk(self) -> None: """ Queues the next chunk of text to be sent to the client. """ - # Are we already waiting to send a chunk? - if self._chunk_queued: - return - - # Queue a chunk of text to be sent - self._chunk_queued = True - - def create_activity(): - self._chunk_queued = False - if self._ended: - # Send final message - activity = Activity( - type="message", - text=self._message or "end stream response", - attachments=self._attachments or [], - entities=[ - Entity( - type="streaminfo", - stream_type="final", - stream_sequence=self._sequence_number, - ) - ], - ) - elif self._is_streaming_channel: - # Send typing activity - activity = Activity( - type="typing", - text=self._message, - entities=[ - Entity( - type="streaminfo", - stream_type="streaming", - stream_sequence=self._sequence_number, - ) - ], - ) - else: + # Are we already waiting to send a chunk? (check atomically) + async with self._state_lock: + if self._chunk_queued: return - self._sequence_number += 1 - return activity + self._chunk_queued = True + + # Create activity factory that captures current state + async def create_activity(): + # Capture message text under thread lock + with self._message_lock: + message = self._message + + async with self._state_lock: + if self._ended: + # Send final message + activity = Activity( + type="message", + text=message or "end stream response", + attachments=self._attachments or [], + entities=[ + Entity( + type="streaminfo", + stream_type="final", + stream_sequence=self._sequence_number, + ) + ], + ) + elif self._is_streaming_channel: + # Send typing activity + activity = Activity( + type="typing", + text=message, + entities=[ + Entity( + type="streaminfo", + stream_type="streaming", + stream_sequence=self._sequence_number, + ) + ], + ) + else: + self._chunk_queued = False + return None + + self._sequence_number += 1 + self._chunk_queued = False + return activity self._queue_activity(create_activity) @@ -313,37 +401,67 @@ def _queue_activity(self, factory: Callable[[], Activity]) -> None: """ Queues an activity to be sent to the client. """ - self._queue.append(factory) + self._queue.put_nowait(factory) + + # Ensure a drain task is running to process the queue + async def ensure_drain_task(): + async with self._drain_task_lock: + if not self._drain_task or self._drain_task.done(): + try: + self._drain_task = asyncio.create_task(self._drain_queue()) + except Exception as err: + logger.error(f"Failed to create drain task: {err}") + self._drain_task = None + raise - # If there's no sync in progress, start one - if not self._queue_sync: - self._queue_sync = asyncio.create_task(self._drain_queue()) + asyncio.create_task(ensure_drain_task()) async def _drain_queue(self) -> None: """ Sends any queued activities to the client until the queue is empty. """ try: - logger.debug(f"Draining queue with {len(self._queue)} activities.") - while self._queue: - factory = self._queue.pop(0) - activity = factory() - if activity: - await self._send_activity(activity) + logger.debug("Draining queue with %s activities.", self._queue.qsize()) + while True: + # Check cancellation flag under lock + async with self._state_lock: + if self._cancelled: + break + + try: + factory = self._queue.get_nowait() + except asyncio.QueueEmpty: + break + + try: + activity = await factory() + # Check cancellation again before sending + async with self._state_lock: + cancelled = self._cancelled + if activity and not cancelled: + await self._send_activity(activity) + finally: + self._queue.task_done() + except asyncio.CancelledError: + logger.debug("Queue drain task was cancelled.") + raise except Exception as err: if ( "403" in str(err) and self._context.activity.channel_id == Channels.ms_teams ): logger.warning("Teams channel stopped the stream.") - self._cancelled = True + async with self._state_lock: + self._cancelled = True else: logger.error( f"Error occurred when sending activity while streaming: {err}" ) raise finally: - self._queue_sync = None + # Always clean up the task reference + async with self._drain_task_lock: + self._drain_task = None async def _send_activity(self, activity: Activity) -> None: """ @@ -352,6 +470,19 @@ async def _send_activity(self, activity: Activity) -> None: Args: activity: The activity to send. """ + # Capture message under thread lock + with self._message_lock: + message = self._message + + # Capture other state snapshot under async lock + async with self._state_lock: + stream_id = self._stream_id + citations = self._citations + ended = self._ended + enable_feedback_loop = self._enable_feedback_loop + feedback_loop_type = self._feedback_loop_type + enable_generated_by_ai_label = self._enable_generated_by_ai_label + sensitivity_label = self._sensitivity_label streaminfo_entity = None @@ -370,15 +501,13 @@ async def _send_activity(self, activity: Activity) -> None: activity.entities.append(streaminfo_entity) # Set activity ID to the assigned stream ID - if self._stream_id: - activity.id = self._stream_id - streaminfo_entity.stream_id = self._stream_id + if stream_id: + activity.id = stream_id + streaminfo_entity.stream_id = stream_id - if self._citations and len(self._citations) > 0 and not self._ended: + if citations and len(citations) > 0 and not ended: # Filter out the citations unused in content. - curr_citations = CitationUtil.get_used_citations( - self._message, self._citations - ) + curr_citations = CitationUtil.get_used_citations(message, citations) if curr_citations: activity.entities.append( Entity( @@ -391,21 +520,22 @@ async def _send_activity(self, activity: Activity) -> None: ) # Add in Powered by AI feature flags - if self._ended: - if self._enable_feedback_loop and self._feedback_loop_type: + if ended: + if enable_feedback_loop and feedback_loop_type: # Add feedback loop to streaminfo entity - streaminfo_entity.feedback_loop = {"type": self._feedback_loop_type} + streaminfo_entity.feedback_loop = {"type": feedback_loop_type} else: # Add feedback loop enabled to streaminfo entity - streaminfo_entity.feedback_loop_enabled = self._enable_feedback_loop + streaminfo_entity.feedback_loop_enabled = enable_feedback_loop # Add in Generated by AI - if self._enable_generated_by_ai_label: - activity.add_ai_metadata(self._citations, self._sensitivity_label) + if enable_generated_by_ai_label: + activity.add_ai_metadata(citations, sensitivity_label) # Send activity response = await self._context.send_activity(activity) await asyncio.sleep(self._interval) - # Save assigned stream ID - if not self._stream_id and response: - self._stream_id = response.id + # Save assigned stream ID atomically + async with self._state_lock: + if not self._stream_id and response: + self._stream_id = response.id diff --git a/tests/hosting_aiohttp/test_streaming_response.py b/tests/hosting_aiohttp/test_streaming_response.py new file mode 100644 index 00000000..e423d173 --- /dev/null +++ b/tests/hosting_aiohttp/test_streaming_response.py @@ -0,0 +1,594 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import pytest +from unittest.mock import AsyncMock, Mock, MagicMock + +from microsoft_agents.activity import ( + Activity, + ActivityTypes, + ChannelAccount, + Channels, + ConversationAccount, + DeliveryModes, + Entity, + ResourceResponse, +) +from microsoft_agents.hosting.core import TurnContext, ChannelAdapter +from microsoft_agents.hosting.aiohttp.app.streaming import StreamingResponse, Citation + + +class MockAdapter(ChannelAdapter): + """Mock adapter for testing.""" + + def __init__(self): + self.sent_activities = [] + self._response_id = 1 + + async def send_activities(self, context, activities): + responses = [] + for activity in activities: + self.sent_activities.append(activity) + responses.append(ResourceResponse(id=f"response_{self._response_id}")) + self._response_id += 1 + return responses + + async def update_activity(self, context, activity): + return ResourceResponse(id=activity.id) + + async def delete_activity(self, context, reference): + pass + + +def create_test_activity( + channel_id: str = Channels.direct_line, + delivery_mode: str = DeliveryModes.normal, +) -> Activity: + """Create a test activity.""" + return Activity( + type=ActivityTypes.message, + id="test-activity-id", + text="test message", + from_property=ChannelAccount(id="user-id", name="User"), + recipient=ChannelAccount(id="bot-id", name="Bot"), + conversation=ConversationAccount(id="conversation-id"), + channel_id=channel_id, + service_url="https://test.example.com", + delivery_mode=delivery_mode, + ) + + +@pytest.fixture +def mock_adapter(): + """Create a mock adapter.""" + return MockAdapter() + + +@pytest.fixture +def turn_context(mock_adapter): + """Create a turn context for testing.""" + activity = create_test_activity() + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def streaming_context(mock_adapter): + """Create a turn context with streaming enabled (DirectLine).""" + activity = create_test_activity(channel_id=Channels.direct_line) + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def teams_context(mock_adapter): + """Create a turn context for Teams channel.""" + activity = create_test_activity(channel_id=Channels.ms_teams) + return TurnContext(mock_adapter, activity) + + +@pytest.fixture +def delivery_mode_context(mock_adapter): + """Create a turn context with delivery mode set to stream.""" + activity = create_test_activity( + channel_id="custom-channel", delivery_mode=DeliveryModes.stream + ) + return TurnContext(mock_adapter, activity) + + +class TestStreamingResponseInitialization: + """Test StreamingResponse initialization and configuration.""" + + def test_init_basic(self, turn_context): + """Test basic initialization.""" + response = StreamingResponse(turn_context) + assert response._context == turn_context + assert response._sequence_number == 1 + assert response._stream_id is None + assert response._message == "" + assert response._ended is False + assert response._cancelled is False + + def test_init_direct_line_sets_streaming_channel(self, streaming_context): + """Test DirectLine channel enables streaming.""" + response = StreamingResponse(streaming_context) + assert response._is_streaming_channel is True + assert response._interval == 0.5 + + def test_init_teams_sets_streaming_channel(self, teams_context): + """Test Teams channel enables streaming.""" + response = StreamingResponse(teams_context) + assert response._is_streaming_channel is True + assert response._interval == 1.0 + + def test_init_delivery_mode_stream(self, delivery_mode_context): + """Test delivery_mode='stream' enables streaming.""" + response = StreamingResponse(delivery_mode_context) + assert response._is_streaming_channel is True + assert response._interval == 0.1 + + +class TestStreamingResponseProperties: + """Test StreamingResponse properties.""" + + def test_stream_id_property_initial(self, turn_context): + """Test stream_id property returns None initially.""" + response = StreamingResponse(turn_context) + assert response.stream_id is None + + def test_citations_property_initial(self, turn_context): + """Test citations property returns empty list initially.""" + response = StreamingResponse(turn_context) + assert response.citations == [] + + def test_updates_sent_property_initial(self, turn_context): + """Test updates_sent property returns 0 initially.""" + response = StreamingResponse(turn_context) + assert response.updates_sent == 0 + + +class TestQueueInformativeUpdate: + """Test queue_informative_update method.""" + + @pytest.mark.asyncio + async def test_queue_informative_update_non_streaming_channel(self, turn_context): + """Test informative update is not queued for non-streaming channels.""" + # Create context with non-streaming channel + activity = create_test_activity(channel_id="non-streaming") + context = TurnContext(MockAdapter(), activity) + response = StreamingResponse(context) + + response.queue_informative_update("test message") + await asyncio.sleep(0.1) # Allow queue to process + + # Should not queue anything for non-streaming channels + assert context.adapter.sent_activities == [] + + @pytest.mark.asyncio + async def test_queue_informative_update_streaming_channel(self, streaming_context): + """Test informative update is queued for streaming channels.""" + response = StreamingResponse(streaming_context) + + response.queue_informative_update("Starting process...") + await response.wait_for_queue() + + assert len(streaming_context.adapter.sent_activities) == 1 + activity = streaming_context.adapter.sent_activities[0] + assert activity.type == "typing" + assert activity.text == "Starting process..." + assert len(activity.entities) == 1 + assert activity.entities[0].type == "streaminfo" + assert activity.entities[0].stream_type == "informative" + assert activity.entities[0].stream_sequence == 1 + + @pytest.mark.asyncio + async def test_queue_informative_update_raises_after_ended(self, streaming_context): + """Test informative update raises error after stream ended.""" + response = StreamingResponse(streaming_context) + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + response.queue_informative_update("test") + + +class TestQueueTextChunk: + """Test queue_text_chunk method.""" + + @pytest.mark.asyncio + async def test_queue_text_chunk_simple(self, streaming_context): + """Test queueing a simple text chunk.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("Hello ") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() + + assert len(streaming_context.adapter.sent_activities) == 1 + activity = streaming_context.adapter.sent_activities[0] + assert activity.type == "typing" + assert activity.text == "Hello " + + @pytest.mark.asyncio + async def test_queue_text_chunk_multiple(self, streaming_context): + """Test queueing multiple text chunks.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("Hello ") + response.queue_text_chunk("World") + response.queue_text_chunk("!") + await asyncio.sleep(0.01) # Allow background tasks to start + await response.wait_for_queue() + + # Message should accumulate + assert response.get_message() == "Hello World!" + assert len(streaming_context.adapter.sent_activities) >= 1 + + @pytest.mark.asyncio + async def test_queue_text_chunk_cancelled(self, streaming_context): + """Test queueing text chunk after cancellation.""" + response = StreamingResponse(streaming_context) + await response.cancel() + + # Should not raise, just return silently + response.queue_text_chunk("test") + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_queue_text_chunk_raises_after_ended(self, streaming_context): + """Test queueing text chunk after stream ended raises error.""" + response = StreamingResponse(streaming_context) + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + response.queue_text_chunk("more") + + @pytest.mark.asyncio + async def test_queue_text_chunk_updates_sequence(self, streaming_context): + """Test that text chunks increment sequence number.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() # Wait for first chunk to be sent + + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() # Wait for second chunk to be sent + + # Updates_sent should reflect the number of activities sent + # Each chunk should result in at least one activity + assert response.updates_sent >= 2 + + +class TestEndStream: + """Test end_stream method.""" + + @pytest.mark.asyncio + async def test_end_stream_sends_final_message(self, streaming_context): + """Test end_stream sends a final message activity.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("Hello World") + await asyncio.sleep(0.01) # Allow background task to start + await response.end_stream() + + # Find the final message + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + activity = final_activities[0] + assert activity.text == "Hello World" + assert any( + e.type == "streaminfo" and e.stream_type == "final" + for e in activity.entities + ) + + @pytest.mark.asyncio + async def test_end_stream_twice_raises_error(self, streaming_context): + """Test calling end_stream twice raises error.""" + response = StreamingResponse(streaming_context) + + await response.end_stream() + + with pytest.raises(RuntimeError, match="stream has already ended"): + await response.end_stream() + + @pytest.mark.asyncio + async def test_end_stream_with_empty_message(self, streaming_context): + """Test end_stream with no text queued.""" + response = StreamingResponse(streaming_context) + await response.end_stream() + + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + assert final_activities[0].text == "end stream response" + + +class TestSetAttachments: + """Test set_attachments method.""" + + @pytest.mark.asyncio + async def test_set_attachments(self, streaming_context): + """Test setting attachments on final message.""" + from microsoft_agents.activity import Attachment + + response = StreamingResponse(streaming_context) + + attachments = [ + Attachment(content_type="text/plain", content="test attachment") + ] + response.set_attachments(attachments) + await response.end_stream() + + final_activities = [ + a for a in streaming_context.adapter.sent_activities if a.type == "message" + ] + assert len(final_activities) == 1 + assert final_activities[0].attachments == attachments + + +class TestSetCitations: + """Test set_citations method.""" + + @pytest.mark.asyncio + async def test_set_citations(self, streaming_context): + """Test setting citations.""" + response = StreamingResponse(streaming_context) + + citations = [ + Citation( + title="Source 1", + content="Content 1", + filepath="file1.txt", + url="http://example.com/1", + ), + Citation( + title="Source 2", + content="Content 2", + filepath="file2.txt", + url="http://example.com/2", + ), + ] + + response.set_citations(citations) + await asyncio.sleep(0.05) # Allow async task to complete + + assert response.citations is not None + assert len(response.citations) == 2 + assert response.citations[0].position == 1 + assert response.citations[1].position == 2 + + @pytest.mark.asyncio + async def test_set_citations_multiple_calls(self, streaming_context): + """Test setting citations multiple times appends.""" + response = StreamingResponse(streaming_context) + + citations1 = [ + Citation( + title="Source 1", + content="Content 1", + filepath="file1.txt", + url="http://example.com/1", + ) + ] + response.set_citations(citations1) + await asyncio.sleep(0.05) # Allow async task to complete + + citations2 = [ + Citation( + title="Source 2", + content="Content 2", + filepath="file2.txt", + url="http://example.com/2", + ) + ] + response.set_citations(citations2) + await asyncio.sleep(0.05) # Allow async task to complete + + assert len(response.citations) == 2 + + +class TestFeedbackAndLabels: + """Test feedback loop and AI label settings.""" + + def test_set_feedback_loop(self, turn_context): + """Test setting feedback loop.""" + response = StreamingResponse(turn_context) + response.set_feedback_loop(True) + assert response._enable_feedback_loop is True + + def test_set_feedback_loop_type(self, turn_context): + """Test setting feedback loop type.""" + response = StreamingResponse(turn_context) + response.set_feedback_loop_type("custom") + assert response._feedback_loop_type == "custom" + + def test_set_generated_by_ai_label(self, turn_context): + """Test setting generated by AI label.""" + response = StreamingResponse(turn_context) + response.set_generated_by_ai_label(True) + assert response._enable_generated_by_ai_label is True + + +class TestGetMessage: + """Test get_message method.""" + + @pytest.mark.asyncio + async def test_get_message_initial(self, turn_context): + """Test get_message returns empty string initially.""" + response = StreamingResponse(turn_context) + assert response.get_message() == "" + + @pytest.mark.asyncio + async def test_get_message_after_chunks(self, streaming_context): + """Test get_message returns accumulated text.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("Hello ") + response.queue_text_chunk("World") + + # Message is updated synchronously, no need to wait + assert response.get_message() == "Hello World" + + +class TestWaitForQueue: + """Test wait_for_queue method.""" + + @pytest.mark.asyncio + async def test_wait_for_queue_empty(self, turn_context): + """Test wait_for_queue returns immediately when queue is empty.""" + response = StreamingResponse(turn_context) + await response.wait_for_queue() # Should not hang + + @pytest.mark.asyncio + async def test_wait_for_queue_processes_all(self, streaming_context): + """Test wait_for_queue waits for all items to process.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() # Wait for chunk to be sent + + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() # Wait for chunk to be sent + + response.queue_text_chunk("chunk3") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() # Wait for chunk to be sent + await response.wait_for_queue() + + # All activities should be sent + assert len(streaming_context.adapter.sent_activities) >= 3 + + +class TestCleanupAndCancel: + """Test cleanup and cancel methods.""" + + @pytest.mark.asyncio + async def test_cleanup(self, streaming_context): + """Test cleanup method.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + await response.cleanup() + + # Queue should be empty after cleanup + assert response._queue.empty() + + @pytest.mark.asyncio + async def test_cancel(self, streaming_context): + """Test cancel method.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + await response.cancel() + + # Cancelled flag should be set + assert response._cancelled is True + + @pytest.mark.asyncio + async def test_cancel_stops_processing(self, streaming_context): + """Test cancel stops further processing.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start + await response.cancel() + + # Subsequent chunks should not raise, just be ignored + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.1) + + +class TestContextManager: + """Test async context manager protocol.""" + + @pytest.mark.asyncio + async def test_context_manager_cleanup(self, streaming_context): + """Test context manager calls cleanup on exit.""" + async with StreamingResponse(streaming_context) as response: + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + + # After exiting context, cleanup should have been called + assert response._queue.empty() + + @pytest.mark.asyncio + async def test_context_manager_with_exception(self, streaming_context): + """Test context manager cleans up even with exception.""" + try: + async with StreamingResponse(streaming_context) as response: + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + raise ValueError("Test exception") + except ValueError: + pass + + # Cleanup should still have been called + assert response._queue.empty() + + +class TestStreamIdAssignment: + """Test stream ID assignment.""" + + @pytest.mark.asyncio + async def test_stream_id_assigned_after_first_send(self, streaming_context): + """Test stream_id is assigned after first activity is sent.""" + response = StreamingResponse(streaming_context) + + assert response.stream_id is None + + response.queue_text_chunk("test") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() + + # Stream ID should be assigned after sending + assert response.stream_id is not None + assert response.stream_id.startswith("response_") + + @pytest.mark.asyncio + async def test_stream_id_consistent_across_activities(self, streaming_context): + """Test all activities in stream use same stream ID.""" + response = StreamingResponse(streaming_context) + + response.queue_text_chunk("chunk1") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() + + first_stream_id = response.stream_id + + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background task to start + await response.wait_for_queue() + + assert response.stream_id == first_stream_id + + +class TestSequenceNumbers: + """Test sequence number management.""" + + @pytest.mark.asyncio + async def test_sequence_numbers_increment(self, streaming_context): + """Test sequence numbers increment correctly.""" + response = StreamingResponse(streaming_context) + + response.queue_informative_update("info1") + response.queue_text_chunk("chunk1") + response.queue_text_chunk("chunk2") + await asyncio.sleep(0.01) # Allow background tasks to start + await response.wait_for_queue() + + # Each activity should have increasing sequence numbers + for i, activity in enumerate(streaming_context.adapter.sent_activities): + stream_entity = next( + (e for e in activity.entities if e.type == "streaminfo"), None + ) + assert stream_entity is not None + assert stream_entity.stream_sequence == i + 1