-
Notifications
You must be signed in to change notification settings - Fork 579
Add parallel reading support to S3SessionManager.list_messages() #1186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0d154f9
15f9bd4
69ebb61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import json | ||
| import logging | ||
| from concurrent.futures import ThreadPoolExecutor, as_completed | ||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast | ||
|
|
||
| import boto3 | ||
|
|
@@ -23,6 +24,7 @@ | |
| AGENT_PREFIX = "agent_" | ||
| MESSAGE_PREFIX = "message_" | ||
| MULTI_AGENT_PREFIX = "multi_agent_" | ||
| DEFAULT_READ_THREAD_COUNT = 1 | ||
|
|
||
|
|
||
| class S3SessionManager(RepositorySessionManager, SessionRepository): | ||
|
|
@@ -50,6 +52,7 @@ def __init__( | |
| boto_session: Optional[boto3.Session] = None, | ||
| boto_client_config: Optional[BotocoreConfig] = None, | ||
| region_name: Optional[str] = None, | ||
| max_parallel_reads: int = DEFAULT_READ_THREAD_COUNT, | ||
| **kwargs: Any, | ||
| ): | ||
| """Initialize S3SessionManager with S3 storage. | ||
|
|
@@ -62,11 +65,17 @@ def __init__( | |
| boto_session: Optional boto3 session | ||
| boto_client_config: Optional boto3 client configuration | ||
| region_name: AWS region for S3 storage | ||
| max_parallel_reads: Maximum number of parallel S3 read operations for list_messages(). | ||
| Defaults to 1 (sequential) for backward compatibility and safety. | ||
| Set to a higher value (e.g., 10) for better performance with many messages. | ||
| Can be overridden per-call via list_messages() kwargs. | ||
| **kwargs: Additional keyword arguments for future extensibility. | ||
| """ | ||
| self.bucket = bucket | ||
| self.prefix = prefix | ||
|
|
||
| self.max_parallel_reads = max_parallel_reads | ||
|
|
||
| session = boto_session or boto3.Session(region_name=region_name) | ||
|
|
||
| # Add strands-agents to the request user agent | ||
|
|
@@ -259,7 +268,24 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio | |
| def list_messages( | ||
| self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any | ||
| ) -> List[SessionMessage]: | ||
| """List messages for an agent with pagination from S3.""" | ||
| """List messages for an agent with pagination from S3. | ||
|
|
||
| Args: | ||
| session_id: ID of the session | ||
| agent_id: ID of the agent | ||
| limit: Optional limit on number of messages to return | ||
| offset: Optional offset for pagination | ||
| **kwargs: Additional keyword arguments. Supports: | ||
|
|
||
| - max_parallel_reads: Override the instance-level max_parallel_reads setting | ||
|
|
||
| Returns: | ||
| List of SessionMessage objects, sorted by message_id. | ||
|
|
||
| Raises: | ||
| ValueError: If max_parallel_reads override is not a positive integer. | ||
| SessionException: If S3 error occurs during message retrieval. | ||
| """ | ||
| messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" | ||
| try: | ||
| paginator = self.client.get_paginator("list_objects_v2") | ||
|
|
@@ -287,10 +313,42 @@ def list_messages( | |
| else: | ||
| message_keys = message_keys[offset:] | ||
|
|
||
| # Load only the required message objects | ||
| # Load message objects in parallel for better performance | ||
| messages: List[SessionMessage] = [] | ||
| for key in message_keys: | ||
| message_data = self._read_s3_object(key) | ||
| if not message_keys: | ||
| return messages | ||
|
|
||
| # Use ThreadPoolExecutor to fetch messages concurrently | ||
| # Allow per-call override of max_parallel_reads via kwargs, otherwise use instance default | ||
| max_workers = min(kwargs.get("max_parallel_reads", self.max_parallel_reads), len(message_keys)) | ||
|
|
||
| # Optimize for single worker case - avoid thread pool overhead | ||
| if max_workers == 1: | ||
| for key in message_keys: | ||
| message_data = self._read_s3_object(key) | ||
| if message_data: | ||
| messages.append(SessionMessage.from_dict(message_data)) | ||
| return messages | ||
|
|
||
| with ThreadPoolExecutor(max_workers=max_workers) as executor: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of a pool of threads, can we instead create one thread and use an asyncio event loop? We had a similar approach here where an async function was created to make the S3 call. This ended up getting reverted because of a bug when calling it from within an already-running event loop. To fix that issue, you can just create a new thread, and then schedule the asyncio calls on that new thread:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for this detailed context. @zastrowm also mentioned it. I think there are several reasons
|
||
| # Submit all read tasks | ||
| future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys} | ||
|
|
||
| # Create a mapping from key to index to maintain order | ||
| key_to_index = {key: idx for idx, key in enumerate(message_keys)} | ||
|
|
||
| # Initialize results list with None placeholders to maintain order | ||
| results: List[Optional[Dict[str, Any]]] = [None] * len(message_keys) | ||
|
|
||
| # Process results as they complete | ||
| for future in as_completed(future_to_key): | ||
| key = future_to_key[future] | ||
| message_data = future.result() | ||
| # Store result at the correct index to maintain order | ||
| results[key_to_index[key]] = message_data | ||
|
|
||
| # Convert results to SessionMessage objects, filtering out None values | ||
| for message_data in results: | ||
| if message_data: | ||
| messages.append(SessionMessage.from_dict(message_data)) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.