diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 7d081cf09..b773e9511 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,6 +2,7 @@ import json import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast import boto3 @@ -23,6 +24,7 @@ AGENT_PREFIX = "agent_" MESSAGE_PREFIX = "message_" MULTI_AGENT_PREFIX = "multi_agent_" +DEFAULT_READ_THREAD_COUNT = 1 class S3SessionManager(RepositorySessionManager, SessionRepository): @@ -50,6 +52,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + max_parallel_reads: int = DEFAULT_READ_THREAD_COUNT, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -62,11 +65,17 @@ def __init__( boto_session: Optional boto3 session boto_client_config: Optional boto3 client configuration region_name: AWS region for S3 storage + max_parallel_reads: Maximum number of parallel S3 read operations for list_messages(). + Defaults to 1 (sequential) for backward compatibility and safety. + Set to a higher value (e.g., 10) for better performance with many messages. + Can be overridden per-call via list_messages() kwargs. **kwargs: Additional keyword arguments for future extensibility. """ self.bucket = bucket self.prefix = prefix + self.max_parallel_reads = max_parallel_reads + session = boto_session or boto3.Session(region_name=region_name) # Add strands-agents to the request user agent @@ -259,7 +268,24 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio def list_messages( self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any ) -> List[SessionMessage]: - """List messages for an agent with pagination from S3.""" + """List messages for an agent with pagination from S3. + + Args: + session_id: ID of the session + agent_id: ID of the agent + limit: Optional limit on number of messages to return + offset: Optional offset for pagination + **kwargs: Additional keyword arguments. Supports: + + - max_parallel_reads: Override the instance-level max_parallel_reads setting + + Returns: + List of SessionMessage objects, sorted by message_id. + + Raises: + ValueError: If max_parallel_reads override is not a positive integer. + SessionException: If S3 error occurs during message retrieval. + """ messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: paginator = self.client.get_paginator("list_objects_v2") @@ -287,10 +313,42 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load only the required message objects + # Load message objects in parallel for better performance messages: List[SessionMessage] = [] - for key in message_keys: - message_data = self._read_s3_object(key) + if not message_keys: + return messages + + # Use ThreadPoolExecutor to fetch messages concurrently + # Allow per-call override of max_parallel_reads via kwargs, otherwise use instance default + max_workers = min(kwargs.get("max_parallel_reads", self.max_parallel_reads), len(message_keys)) + + # Optimize for single worker case - avoid thread pool overhead + if max_workers == 1: + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + return messages + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all read tasks + future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys} + + # Create a mapping from key to index to maintain order + key_to_index = {key: idx for idx, key in enumerate(message_keys)} + + # Initialize results list with None placeholders to maintain order + results: List[Optional[Dict[str, Any]]] = [None] * len(message_keys) + + # Process results as they complete + for future in as_completed(future_to_key): + key = future_to_key[future] + message_data = future.result() + # Store result at the correct index to maintain order + results[key_to_index[key]] = message_data + + # Convert results to SessionMessage objects, filtering out None values + for message_data in results: if message_data: messages.append(SessionMessage.from_dict(message_data)) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 719fbc2c9..67f9f2311 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -1,7 +1,7 @@ """Tests for S3SessionManager.""" import json -from unittest.mock import Mock +from unittest.mock import Mock, patch import boto3 import pytest @@ -308,6 +308,279 @@ def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent) assert len(result) == 5 +def test_list_messages_default_max_parallel_reads(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test that default max_parallel_reads is 1 (sequential for backward compatibility).""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2") + assert manager.max_parallel_reads == 1 + + +def test_list_messages_instance_level_max_parallel_reads(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test instance-level max_parallel_reads configuration.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=5) + assert manager.max_parallel_reads == 5 + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for index in range(20): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Verify list_messages works with custom max_parallel_reads + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id) + assert len(result) == 20 + # Verify messages are in correct order + for i, msg in enumerate(result): + assert msg.message_id == i + + +def test_list_messages_per_call_override_max_parallel_reads(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test per-call override of max_parallel_reads via kwargs.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=20) + assert manager.max_parallel_reads == 20 + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for index in range(15): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Override max_parallel_reads for this call + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id, max_parallel_reads=3) + assert len(result) == 15 + # Verify messages are in correct order + for i, msg in enumerate(result): + assert msg.message_id == i + + +def test_list_messages_max_parallel_reads_with_few_messages(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test that max_parallel_reads is capped by number of messages.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=100) + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create only 3 messages + for index in range(3): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Should work correctly even with max_parallel_reads > number of messages + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id) + assert len(result) == 3 + + +def test_list_messages_max_parallel_reads_with_many_messages(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test max_parallel_reads with a large number of messages.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=5) + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create 50 messages + for index in range(50): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Should work correctly with max_parallel_reads < number of messages + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id) + assert len(result) == 50 + # Verify messages are in correct order + for i, msg in enumerate(result): + assert msg.message_id == i + + +def test_list_messages_max_parallel_reads_with_pagination(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test max_parallel_reads works correctly with pagination.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=3) + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create 20 messages + for index in range(20): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Test with limit + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=5, max_parallel_reads=2) + assert len(result) == 5 + assert result[0].message_id == 0 + assert result[4].message_id == 4 + + # Test with offset + result = manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=10, max_parallel_reads=4) + assert len(result) == 10 + assert result[0].message_id == 10 + assert result[9].message_id == 19 + + +@patch("strands.session.s3_session_manager.as_completed") +@patch("strands.session.s3_session_manager.ThreadPoolExecutor") +def test_list_messages_uses_correct_max_workers( + mock_thread_pool_executor, mock_as_completed, mocked_aws, s3_bucket, sample_session, sample_agent +): + """Test that ThreadPoolExecutor is called with correct max_workers value.""" + from concurrent.futures import Future + + # Create a mock executor that tracks the max_workers value and returns futures + mock_executor_instance = Mock() + mock_thread_pool_executor.return_value.__enter__.return_value = mock_executor_instance + mock_thread_pool_executor.return_value.__exit__.return_value = None + + # Track futures for as_completed + futures_list = [] + + # Mock submit to return futures that complete immediately with message data + def mock_submit(func, key): + future = Future() + # Call the actual _read_s3_object function to get real data + try: + result = func(key) + future.set_result(result) + except Exception as e: + future.set_exception(e) + futures_list.append(future) + return future + + mock_executor_instance.submit.side_effect = mock_submit + # Mock as_completed to return the futures + mock_as_completed.side_effect = lambda futures: iter(futures) + + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=7) + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create 15 messages + for index in range(15): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Call list_messages + futures_list.clear() + manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + # Verify ThreadPoolExecutor was called with max_workers=7 (instance default) + mock_thread_pool_executor.assert_called_once() + call_kwargs = mock_thread_pool_executor.call_args[1] + assert call_kwargs["max_workers"] == 7 + + # Reset and test per-call override + mock_thread_pool_executor.reset_mock() + mock_as_completed.reset_mock() + futures_list.clear() + manager.list_messages(sample_session.session_id, sample_agent.agent_id, max_parallel_reads=3) + + # Verify ThreadPoolExecutor was called with max_workers=3 (per-call override) + mock_thread_pool_executor.assert_called_once() + call_kwargs = mock_thread_pool_executor.call_args[1] + assert call_kwargs["max_workers"] == 3 + + +@patch("strands.session.s3_session_manager.as_completed") +@patch("strands.session.s3_session_manager.ThreadPoolExecutor") +def test_list_messages_max_workers_capped_by_message_count( + mock_thread_pool_executor, mock_as_completed, mocked_aws, s3_bucket, sample_session, sample_agent +): + """Test that max_workers is capped by the number of messages.""" + from concurrent.futures import Future + + # Create a mock executor that tracks the max_workers value and returns futures + mock_executor_instance = Mock() + mock_thread_pool_executor.return_value.__enter__.return_value = mock_executor_instance + mock_thread_pool_executor.return_value.__exit__.return_value = None + + # Track futures for as_completed + futures_list = [] + + # Mock submit to return futures that complete immediately with message data + def mock_submit(func, key): + future = Future() + # Call the actual _read_s3_object function to get real data + try: + result = func(key) + future.set_result(result) + except Exception as e: + future.set_exception(e) + futures_list.append(future) + return future + + mock_executor_instance.submit.side_effect = mock_submit + # Mock as_completed to return the futures + mock_as_completed.side_effect = lambda futures: iter(futures) + + manager = S3SessionManager(session_id="test", bucket=s3_bucket, region_name="us-west-2", max_parallel_reads=100) + + # Create session and agent + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + # Create only 5 messages + for index in range(5): + message = SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {index}")], + }, + index=index, + ) + manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # Call list_messages + manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + # Verify ThreadPoolExecutor was called with max_workers=5 (capped by message count) + mock_thread_pool_executor.assert_called_once() + call_kwargs = mock_thread_pool_executor.call_args[1] + assert call_kwargs["max_workers"] == 5 + + def test_update_message(s3_manager, sample_session, sample_agent, sample_message): """Test updating a message in S3.""" # Create session, agent, and message