From 4768beb8490181bb83b838fc40974bdd28cb6c22 Mon Sep 17 00:00:00 2001 From: Steven Leggett Date: Fri, 9 Jan 2026 16:02:46 -0500 Subject: [PATCH] feat: Add cloud session storage for horizontal scaling Implements SessionStorage protocol with S3 and GCS reference implementations to enable cloud persistence for session transcripts. This supports: - Horizontal scaling across multiple servers with shared sessions - Ephemeral filesystem environments (containers, serverless) - Extensible architecture for custom backends (Azure Blob, etc.) Key changes: - Add SessionStorage protocol and BaseSessionStorage ABC - Add S3SessionStorage with support for S3-compatible services (DigitalOcean Spaces, Cloudflare R2, MinIO) - Add GCSSessionStorage for Google Cloud Storage - Add SessionSyncManager for automatic sync via hooks - Add session_storage field to ClaudeAgentOptions - Add optional dependencies: [s3], [gcs], [cloud] - Add SessionStorageError for storage failures - Add 60 unit tests for session storage - Add examples for basic usage and caching patterns - Update README with session storage documentation Closes #432 --- README.md | 97 +- examples/session_storage_cached.py | 670 +++++++++++ examples/session_storage_example.py | 467 ++++++++ pyproject.toml | 12 + src/claude_agent_sdk/__init__.py | 2 + src/claude_agent_sdk/_errors.py | 42 + src/claude_agent_sdk/client.py | 33 + .../session_storage/__init__.py | 86 ++ src/claude_agent_sdk/session_storage/_base.py | 333 ++++++ src/claude_agent_sdk/session_storage/_gcs.py | 365 ++++++ .../session_storage/_protocol.py | 151 +++ src/claude_agent_sdk/session_storage/_s3.py | 414 +++++++ src/claude_agent_sdk/session_storage/_sync.py | 187 +++ src/claude_agent_sdk/types.py | 12 + tests/test_session_storage.py | 1038 +++++++++++++++++ 15 files changed, 3904 insertions(+), 5 deletions(-) create mode 100644 examples/session_storage_cached.py create mode 100644 examples/session_storage_example.py create mode 100644 src/claude_agent_sdk/session_storage/__init__.py create mode 100644 src/claude_agent_sdk/session_storage/_base.py create mode 100644 src/claude_agent_sdk/session_storage/_gcs.py create mode 100644 src/claude_agent_sdk/session_storage/_protocol.py create mode 100644 src/claude_agent_sdk/session_storage/_s3.py create mode 100644 src/claude_agent_sdk/session_storage/_sync.py create mode 100644 tests/test_session_storage.py diff --git a/README.md b/README.md index 790c9249..9efea5a3 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,90 @@ async with ClaudeSDKClient(options=options) as client: print(msg) ``` +## Session Storage (Cloud Persistence) + +Session storage enables cloud persistence for session transcripts, supporting: +- **Horizontal scaling** across multiple servers with shared sessions +- **Ephemeral filesystems** (containers, serverless) +- **Extensible backends** (S3, GCS, or custom implementations) + +> **WARNING:** Cloud storage operations add latency (50-500ms per operation). For production at scale, consider wrapping with a caching layer. 
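+
+As a minimal sketch of that wrapper pattern (the `CachedSessionStorage` class and its `cache` object below are illustrative only, not part of the SDK; complete working wrappers live in `examples/session_storage_cached.py`):
+
+```python
+class CachedSessionStorage:
+    """Illustrative wrapper: check a fast local cache before the cloud backend."""
+
+    def __init__(self, backend, cache):
+        self.backend = backend  # e.g. an S3SessionStorage instance
+        self.cache = cache      # hypothetical async cache (Redis, local files, ...)
+
+    async def download_transcript(self, session_id, local_path):
+        if await self.cache.has(session_id):  # cache hit: skip the cloud round-trip
+            return await self.cache.get(session_id, local_path)
+        found = await self.backend.download_transcript(session_id, local_path)
+        if found:
+            await self.cache.put(session_id, local_path)  # warm the cache for next time
+        return found
+
+    # upload_transcript, exists, delete, list_sessions, get_metadata:
+    # delegate straight to self.backend (omitted for brevity).
+```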
+ +### Installation + +```bash +# For AWS S3 (also works with S3-compatible: DigitalOcean Spaces, Cloudflare R2, MinIO) +pip install claude-agent-sdk[s3] + +# For Google Cloud Storage +pip install claude-agent-sdk[gcs] + +# For both +pip install claude-agent-sdk[cloud] +``` + +### Basic Usage + +```python +from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient +from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + +# Configure S3 storage +storage = S3SessionStorage(S3Config( + bucket="my-sessions", + prefix="claude", + region="us-east-1", +)) + +# Sessions automatically sync to S3 +options = ClaudeAgentOptions(session_storage=storage) + +async with ClaudeSDKClient(options=options) as client: + await client.query("Hello!") + async for msg in client.receive_response(): + print(msg) + # Transcript automatically uploaded to S3 on session end + +# Resume from S3 (on any server) +options = ClaudeAgentOptions( + session_storage=storage, + resume="session-abc123", # Downloads from S3 first +) +``` + +### S3-Compatible Services + +```python +# DigitalOcean Spaces +storage = S3SessionStorage(S3Config( + bucket="my-bucket", + endpoint_url="https://nyc3.digitaloceanspaces.com", + region="nyc3", +)) + +# Cloudflare R2 +storage = S3SessionStorage(S3Config( + bucket="my-bucket", + endpoint_url="https://.r2.cloudflarestorage.com", +)) +``` + +### Custom Backends + +Implement the `SessionStorage` protocol for custom backends (Azure Blob, etc.): + +```python +from claude_agent_sdk.session_storage import BaseSessionStorage + +class AzureBlobSessionStorage(BaseSessionStorage): + async def _do_upload(self, key, local_path): + # Your Azure Blob SDK calls + ... + # Implement other abstract methods +``` + +See [examples/session_storage_example.py](examples/session_storage_example.py) for more examples and [examples/session_storage_cached.py](examples/session_storage_cached.py) for caching patterns. + ## Types See [src/claude_agent_sdk/types.py](src/claude_agent_sdk/types.py) for complete type definitions: @@ -245,11 +329,12 @@ See [src/claude_agent_sdk/types.py](src/claude_agent_sdk/types.py) for complete ```python from claude_agent_sdk import ( - ClaudeSDKError, # Base error - CLINotFoundError, # Claude Code not installed - CLIConnectionError, # Connection issues - ProcessError, # Process failed - CLIJSONDecodeError, # JSON parsing issues + ClaudeSDKError, # Base error + CLINotFoundError, # Claude Code not installed + CLIConnectionError, # Connection issues + ProcessError, # Process failed + CLIJSONDecodeError, # JSON parsing issues + SessionStorageError, # Cloud storage failures ) try: @@ -275,6 +360,8 @@ See [examples/quick_start.py](examples/quick_start.py) for a complete working ex See [examples/streaming_mode.py](examples/streaming_mode.py) for comprehensive examples involving `ClaudeSDKClient`. You can even run interactive examples in IPython from [examples/streaming_mode_ipython.py](examples/streaming_mode_ipython.py). +See [examples/session_storage_example.py](examples/session_storage_example.py) for cloud session storage examples and [examples/session_storage_cached.py](examples/session_storage_cached.py) for caching patterns. 
+ ## Migrating from Claude Code SDK If you're upgrading from the Claude Code SDK (versions < 0.1.0), please see the [CHANGELOG.md](CHANGELOG.md#010) for details on breaking changes and new features, including: diff --git a/examples/session_storage_cached.py b/examples/session_storage_cached.py new file mode 100644 index 00000000..8333717e --- /dev/null +++ b/examples/session_storage_cached.py @@ -0,0 +1,670 @@ +#!/usr/bin/env python3 +"""Caching patterns for session storage in production. + +This file demonstrates how to build caching wrappers around session storage +backends to reduce latency and cost in production environments. + +The SDK provides primitive session storage implementations (S3, GCS) that +directly interact with cloud storage. This "batteries included but removable" +philosophy lets you add caching optimized for your specific needs. + +Caching strategies shown: +- Local file cache (simple, works anywhere) +- Redis cache (distributed, recommended for production) +- LRU memory cache (fast, but process-local) + +Installation: + # Base session storage + pip install claude-agent-sdk[s3] + + # For Redis examples + pip install redis + +Why cache? +- Direct S3/GCS operations: 50-500ms+ latency per operation +- With local cache: <1ms for cache hits +- Cost savings: Fewer cloud storage API calls + +WARNING: Caching adds complexity. Only add it when you have: +1. High request volume (>100 requests/min) +2. Measured latency problems +3. Monitoring to track cache hit rates +""" + +import asyncio +import logging +import time +from pathlib import Path +from typing import Any + +from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient +from claude_agent_sdk.session_storage import SessionMetadata, SessionStorage +from claude_agent_sdk.types import AssistantMessage, ResultMessage, TextBlock + +logger = logging.getLogger(__name__) + + +def display_message(msg): + """Display messages in a standardized format.""" + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print("Result ended") + + +# ============================================================================ +# PATTERN 1: Local File Cache +# ============================================================================ + + +class LocalFileCachedStorage: + """File-based cache wrapper for session storage. + + This is the simplest caching strategy - keeps a local file cache of + recently accessed transcripts. Perfect for single-server deployments + or development environments. + + Benefits: + - Simple to implement (no dependencies) + - Works on any filesystem + - Survives process restarts + + Limitations: + - Not shared across servers + - No automatic eviction (manual cleanup needed) + - File I/O overhead (still faster than S3) + + Example: + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> backend = S3SessionStorage(S3Config(bucket="my-bucket")) + >>> cached = LocalFileCachedStorage(backend, cache_dir="/tmp/session-cache") + >>> options = ClaudeAgentOptions(session_storage=cached) + """ + + def __init__( + self, backend: SessionStorage, cache_dir: str | Path = "/tmp/session-cache" + ): + """Initialize file cache wrapper. + + Args: + backend: Underlying storage backend (S3, GCS, etc). + cache_dir: Directory to store cached transcripts. 
+ """ + self.backend = backend + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Initialized file cache at: {self.cache_dir}") + + def _cache_path(self, session_id: str) -> Path: + """Get cache file path for a session.""" + return self.cache_dir / f"{session_id}.jsonl" + + async def upload_transcript(self, session_id: str, local_path: Path | str) -> str: + """Upload transcript and update cache.""" + # Upload to backend + result = await self.backend.upload_transcript(session_id, local_path) + + # Update local cache + cache_path = self._cache_path(session_id) + local_path = Path(local_path) + if local_path.exists(): + # Copy to cache + cache_path.write_bytes(local_path.read_bytes()) + logger.debug(f"Cached transcript for {session_id}") + + return result + + async def download_transcript( + self, session_id: str, local_path: Path | str + ) -> bool: + """Download transcript, using cache if available.""" + cache_path = self._cache_path(session_id) + local_path = Path(local_path) + + # Check cache first + if cache_path.exists(): + logger.info(f"Cache HIT for {session_id}") + local_path.write_bytes(cache_path.read_bytes()) + return True + + # Cache miss - fetch from backend + logger.info(f"Cache MISS for {session_id}") + success = await self.backend.download_transcript(session_id, local_path) + + if success: + # Populate cache + cache_path.write_bytes(local_path.read_bytes()) + logger.debug(f"Populated cache for {session_id}") + + return success + + async def exists(self, session_id: str) -> bool: + """Check if session exists (cache or backend).""" + # Check cache first + if self._cache_path(session_id).exists(): + return True + return await self.backend.exists(session_id) + + async def delete(self, session_id: str) -> bool: + """Delete from both cache and backend.""" + # Remove from cache + cache_path = self._cache_path(session_id) + if cache_path.exists(): + cache_path.unlink() + + # Remove from backend + return await self.backend.delete(session_id) + + async def list_sessions( + self, prefix: str | None = None, limit: int = 100 + ) -> list[SessionMetadata]: + """List sessions from backend (cache doesn't affect listing).""" + return await self.backend.list_sessions(prefix, limit) + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata from backend.""" + return await self.backend.get_metadata(session_id) + + def clear_cache(self) -> int: + """Clear all cached files. + + Returns: + Number of files removed. + """ + count = 0 + for cache_file in self.cache_dir.glob("*.jsonl"): + cache_file.unlink() + count += 1 + logger.info(f"Cleared {count} cached files") + return count + + +# ============================================================================ +# PATTERN 2: Redis Cache (Production-Ready) +# ============================================================================ + + +class RedisCachedStorage: + """Redis-based cache wrapper for session storage. + + This is the recommended production caching strategy. 
Redis provides: + - Distributed cache shared across all servers + - Automatic TTL-based eviction + - High performance (sub-millisecond access) + - Built-in memory management + + Benefits: + - Shared across all servers (consistent cache) + - Automatic expiration (no manual cleanup) + - Very fast (in-memory) + - Production-proven + + Limitations: + - Requires Redis server + - Additional infrastructure cost + - Memory constraints + + Example: + >>> import redis.asyncio as redis + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> + >>> backend = S3SessionStorage(S3Config(bucket="my-bucket")) + >>> redis_client = await redis.from_url("redis://localhost") + >>> cached = RedisCachedStorage(backend, redis_client) + >>> options = ClaudeAgentOptions(session_storage=cached) + """ + + def __init__( + self, + backend: SessionStorage, + redis_client: Any, # redis.asyncio.Redis + ttl: int = 3600, + key_prefix: str = "claude:session:", + ): + """Initialize Redis cache wrapper. + + Args: + backend: Underlying storage backend (S3, GCS, etc). + redis_client: Redis client instance (redis.asyncio.Redis). + ttl: Cache TTL in seconds (default: 1 hour). + key_prefix: Prefix for Redis keys. + """ + self.backend = backend + self.redis = redis_client + self.ttl = ttl + self.key_prefix = key_prefix + logger.info(f"Initialized Redis cache with TTL={ttl}s, prefix={key_prefix!r}") + + def _cache_key(self, session_id: str) -> str: + """Get Redis key for a session.""" + return f"{self.key_prefix}{session_id}" + + async def upload_transcript(self, session_id: str, local_path: Path | str) -> str: + """Upload transcript and update cache.""" + # Upload to backend + result = await self.backend.upload_transcript(session_id, local_path) + + # Update Redis cache + local_path = Path(local_path) + if local_path.exists(): + cache_key = self._cache_key(session_id) + content = local_path.read_bytes() + await self.redis.setex(cache_key, self.ttl, content) + logger.debug(f"Cached transcript for {session_id} in Redis") + + return result + + async def download_transcript( + self, session_id: str, local_path: Path | str + ) -> bool: + """Download transcript, using Redis cache if available.""" + cache_key = self._cache_key(session_id) + local_path = Path(local_path) + + # Check Redis first + cached_content = await self.redis.get(cache_key) + if cached_content: + logger.info(f"Redis cache HIT for {session_id}") + local_path.write_bytes(cached_content) + return True + + # Cache miss - fetch from backend + logger.info(f"Redis cache MISS for {session_id}") + success = await self.backend.download_transcript(session_id, local_path) + + if success: + # Populate Redis cache + content = local_path.read_bytes() + await self.redis.setex(cache_key, self.ttl, content) + logger.debug(f"Populated Redis cache for {session_id}") + + return success + + async def exists(self, session_id: str) -> bool: + """Check if session exists (cache or backend).""" + # Check Redis first + cache_key = self._cache_key(session_id) + if await self.redis.exists(cache_key): + return True + return await self.backend.exists(session_id) + + async def delete(self, session_id: str) -> bool: + """Delete from both cache and backend.""" + # Remove from Redis + cache_key = self._cache_key(session_id) + await self.redis.delete(cache_key) + + # Remove from backend + return await self.backend.delete(session_id) + + async def list_sessions( + self, prefix: str | None = None, limit: int = 100 + ) -> list[SessionMetadata]: + """List sessions from backend.""" + 
return await self.backend.list_sessions(prefix, limit) + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata from backend.""" + return await self.backend.get_metadata(session_id) + + async def clear_cache(self) -> int: + """Clear all cached sessions from Redis. + + Returns: + Number of keys removed. + """ + pattern = f"{self.key_prefix}*" + count = 0 + async for key in self.redis.scan_iter(match=pattern): + await self.redis.delete(key) + count += 1 + logger.info(f"Cleared {count} keys from Redis cache") + return count + + +# ============================================================================ +# PATTERN 3: LRU Memory Cache (Simple, Process-Local) +# ============================================================================ + + +class LRUMemoryCachedStorage: + """LRU in-memory cache wrapper for session storage. + + This strategy keeps recently used sessions in process memory using an + LRU (Least Recently Used) eviction policy. Fastest possible cache, + but not shared across processes/servers. + + Benefits: + - Extremely fast (memory access) + - No external dependencies + - Simple implementation + + Limitations: + - Not shared across servers/processes + - Lost on process restart + - Memory constrained + + Best for: + - Single-server deployments + - Development/testing + - When you have spare memory + + Example: + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> backend = S3SessionStorage(S3Config(bucket="my-bucket")) + >>> cached = LRUMemoryCachedStorage(backend, max_size=100) + >>> options = ClaudeAgentOptions(session_storage=cached) + """ + + def __init__(self, backend: SessionStorage, max_size: int = 100): + """Initialize LRU memory cache wrapper. + + Args: + backend: Underlying storage backend. + max_size: Maximum number of sessions to cache. 
+ """ + self.backend = backend + self.max_size = max_size + # Cache format: {session_id: (content_bytes, access_time)} + self.cache: dict[str, tuple[bytes, float]] = {} + logger.info(f"Initialized LRU memory cache with max_size={max_size}") + + def _evict_if_needed(self) -> None: + """Evict least recently used item if cache is full.""" + if len(self.cache) >= self.max_size: + # Find LRU item + lru_session = min(self.cache.items(), key=lambda x: x[1][1]) + del self.cache[lru_session[0]] + logger.debug(f"Evicted {lru_session[0]} from LRU cache") + + async def upload_transcript(self, session_id: str, local_path: Path | str) -> str: + """Upload transcript and update cache.""" + # Upload to backend + result = await self.backend.upload_transcript(session_id, local_path) + + # Update memory cache + local_path = Path(local_path) + if local_path.exists(): + self._evict_if_needed() + content = local_path.read_bytes() + self.cache[session_id] = (content, time.time()) + logger.debug(f"Cached transcript for {session_id} in memory") + + return result + + async def download_transcript( + self, session_id: str, local_path: Path | str + ) -> bool: + """Download transcript, using memory cache if available.""" + local_path = Path(local_path) + + # Check memory cache first + if session_id in self.cache: + logger.info(f"Memory cache HIT for {session_id}") + content, _ = self.cache[session_id] + # Update access time + self.cache[session_id] = (content, time.time()) + local_path.write_bytes(content) + return True + + # Cache miss - fetch from backend + logger.info(f"Memory cache MISS for {session_id}") + success = await self.backend.download_transcript(session_id, local_path) + + if success: + # Populate memory cache + self._evict_if_needed() + content = local_path.read_bytes() + self.cache[session_id] = (content, time.time()) + logger.debug(f"Populated memory cache for {session_id}") + + return success + + async def exists(self, session_id: str) -> bool: + """Check if session exists.""" + if session_id in self.cache: + return True + return await self.backend.exists(session_id) + + async def delete(self, session_id: str) -> bool: + """Delete from both cache and backend.""" + # Remove from memory + self.cache.pop(session_id, None) + # Remove from backend + return await self.backend.delete(session_id) + + async def list_sessions( + self, prefix: str | None = None, limit: int = 100 + ) -> list[SessionMetadata]: + """List sessions from backend.""" + return await self.backend.list_sessions(prefix, limit) + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata from backend.""" + return await self.backend.get_metadata(session_id) + + def clear_cache(self) -> int: + """Clear all cached sessions from memory.""" + count = len(self.cache) + self.cache.clear() + logger.info(f"Cleared {count} sessions from memory cache") + return count + + def get_cache_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + return { + "size": len(self.cache), + "max_size": self.max_size, + "utilization": len(self.cache) / self.max_size, + } + + +# ============================================================================ +# Example Usage +# ============================================================================ + + +async def example_file_cache(): + """Demonstrate local file cache.""" + print("=== Local File Cache Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + # Create backend + backend = S3SessionStorage( + S3Config(bucket="my-sessions", 
prefix="claude", region="us-east-1") + ) + + # Wrap with file cache + cached_storage = LocalFileCachedStorage(backend, cache_dir="/tmp/claude-cache") + + # Use with SDK + options = ClaudeAgentOptions(session_storage=cached_storage) + + print("First request (cache miss) - will fetch from S3:") + async with ClaudeSDKClient(options=options) as client: + await client.query("What is 2 + 2?") + async for msg in client.receive_response(): + display_message(msg) + + print("\nSecond request (cache hit) - will read from local cache:") + async with ClaudeSDKClient(options=options) as client: + await client.query("What is 3 + 3?") + async for msg in client.receive_response(): + display_message(msg) + + # Cleanup + cached_storage.clear_cache() + print("\n") + + +async def example_redis_cache_pseudocode(): + """Show Redis cache pattern (pseudocode - requires Redis).""" + print("=== Redis Cache Pattern (Pseudocode) ===\n") + + print("# Install dependencies:") + print("pip install redis") + print() + print("# Code:") + print("import redis.asyncio as redis") + print("from claude_agent_sdk.session_storage import S3SessionStorage, S3Config") + print() + print("# Connect to Redis") + print('redis_client = await redis.from_url("redis://localhost:6379")') + print() + print("# Create cached storage") + print("backend = S3SessionStorage(S3Config(bucket='my-sessions'))") + print("cached = RedisCachedStorage(") + print(" backend=backend,") + print(" redis_client=redis_client,") + print(" ttl=3600, # 1 hour TTL") + print(")") + print() + print("# Use with SDK") + print("options = ClaudeAgentOptions(session_storage=cached)") + print() + print("Benefits:") + print("- Shared cache across all servers") + print("- Automatic expiration (TTL)") + print("- Sub-millisecond access times") + print("- Production-proven reliability") + print("\n") + + +async def example_lru_cache(): + """Demonstrate LRU memory cache.""" + print("=== LRU Memory Cache Example ===\n") + + print("Configuration example:") + print() + print("from claude_agent_sdk.session_storage import S3SessionStorage, S3Config") + print() + print("# Create backend") + print("backend = S3SessionStorage(") + print(" S3Config(bucket='my-sessions', prefix='claude', region='us-east-1')") + print(")") + print() + print("# Wrap with LRU cache") + print("cached_storage = LRUMemoryCachedStorage(backend, max_size=10)") + print() + print("# Use with SDK") + print("options = ClaudeAgentOptions(session_storage=cached_storage)") + print() + print("Memory cache benefits:") + print(" - Extremely fast (memory access)") + print(" - No external dependencies") + print(" - Simple implementation") + print(" - Max size: configurable (default shown: 10 sessions)") + print() + print("Note: Perfect for single-server deployments or development.") + print("\n") + + +async def example_cache_comparison(): + """Compare different caching strategies.""" + print("=== Caching Strategy Comparison ===\n") + + print("| Strategy | Latency | Shared | Persistent | Complexity | Cost |") + print("|--------------|---------|--------|------------|------------|-----------|") + print("| No Cache | 50-500ms| N/A | Yes | Low | API calls |") + print("| File Cache | ~1-5ms | No | Yes | Low | Disk |") + print("| Redis Cache | ~0.1ms | Yes | Optional | Medium | Redis |") + print("| Memory Cache | ~0.01ms | No | No | Low | RAM |") + print() + print("Recommendations:") + print() + print("1. 
SINGLE SERVER / DEVELOPMENT:") + print(" -> Use LocalFileCachedStorage or LRUMemoryCachedStorage") + print(" -> Simple, no dependencies, good enough") + print() + print("2. PRODUCTION / MULTI-SERVER:") + print(" -> Use RedisCachedStorage") + print(" -> Shared cache, automatic eviction, proven at scale") + print() + print("3. SERVERLESS / CONTAINERS:") + print(" -> Use Redis or external cache service") + print(" -> File/memory caches reset on each invocation") + print() + print("4. LOW TRAFFIC (<100 req/min):") + print(" -> Don't cache! Direct S3/GCS is fine") + print(" -> Measure first, optimize if needed") + print("\n") + + +async def example_custom_cache(): + """Show how to implement a custom cache strategy.""" + print("=== Custom Cache Implementation ===\n") + + print( + "The SDK provides the SessionStorage protocol - implement it however you want:" + ) + print() + print("class MyCustomCache:") + print(' """Your custom caching logic."""') + print() + print(" def __init__(self, backend, my_cache_system):") + print(" self.backend = backend") + print(" self.cache = my_cache_system") + print() + print(" async def upload_transcript(self, session_id, local_path):") + print(" # Upload to backend") + print( + " result = await self.backend.upload_transcript(session_id, local_path)" + ) + print(" # Update your cache") + print(" await self.cache.set(session_id, local_path)") + print(" return result") + print() + print(" async def download_transcript(self, session_id, local_path):") + print(" # Try cache first") + print(" if await self.cache.has(session_id):") + print(" return await self.cache.get(session_id, local_path)") + print(" # Cache miss - fetch from backend") + print( + " return await self.backend.download_transcript(session_id, local_path)" + ) + print() + print(" # ... implement other SessionStorage methods ...") + print() + print("Examples of custom caches:") + print("- Memcached") + print("- DynamoDB") + print("- Cloudflare KV") + print("- Your own distributed cache") + print("\n") + + +async def main(): + """Run caching examples.""" + print("Claude Agent SDK - Session Storage Caching Patterns") + print("=" * 60) + print() + + # await example_file_cache() + await example_redis_cache_pseudocode() + await example_lru_cache() + await example_cache_comparison() + await example_custom_cache() + + print("=" * 60) + print() + print("Key Takeaways:") + print() + print("1. The SDK provides primitive storage implementations") + print("2. You add caching tailored to your needs") + print("3. Start simple, add caching when you measure latency problems") + print("4. Redis is the production standard for distributed caching") + print("5. File/memory caches work great for single-server deployments") + print() + + +if __name__ == "__main__": + # Set up logging to see cache hits/misses + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + asyncio.run(main()) diff --git a/examples/session_storage_example.py b/examples/session_storage_example.py new file mode 100644 index 00000000..17634877 --- /dev/null +++ b/examples/session_storage_example.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python3 +"""Examples of using cloud session storage with Claude Agent SDK. + +This file demonstrates how to use S3 and GCS session storage backends to +persist session transcripts to cloud storage. 
This enables: + +- Horizontal scaling across multiple servers with shared sessions +- Support for ephemeral filesystems (containers, serverless) +- Session resume from cloud storage +- Persistent conversation history + +Installation: + For S3 (AWS, DigitalOcean Spaces, Cloudflare R2, MinIO): + pip install claude-agent-sdk[s3] + + For GCS (Google Cloud Storage): + pip install claude-agent-sdk[gcs] + +WARNING: Cloud storage operations add latency (50-500ms+ per operation). +For production at scale, see session_storage_cached.py for caching patterns. +""" + +import asyncio +import os + +from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient +from claude_agent_sdk.types import AssistantMessage, ResultMessage, TextBlock + + +def display_message(msg): + """Display messages in a standardized format.""" + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print("Result ended") + + +async def example_s3_basic(): + """Basic S3 usage with AWS credentials. + + This example shows standard AWS S3 configuration. The SDK will use + standard AWS credential chain (environment vars, IAM roles, etc). + """ + print("=== S3 Basic Example ===\n") + + # Import only needed when using S3 + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + # Configure S3 storage + storage = S3SessionStorage( + S3Config( + bucket="my-claude-sessions", + prefix="claude-sessions", # Organize sessions under this prefix + region="us-east-1", # Optional: specify AWS region + # aws_access_key_id="AKIAIOSFODNN7EXAMPLE", # Optional: explicit credentials + # aws_secret_access_key="wJalrXUtnFEMI/...", # Optional: explicit credentials + ) + ) + + # Use storage in options + options = ClaudeAgentOptions( + session_storage=storage, + # Optional: specify local transcript directory + # transcript_dir="/tmp/claude-transcripts", + ) + + # Start a conversation - transcript is synced to S3 + async with ClaudeSDKClient(options=options) as client: + print("User: What is the capital of France?") + await client.query("What is the capital of France?") + + async for msg in client.receive_response(): + display_message(msg) + + # Follow-up question - same session + print("\nUser: What's the population?") + await client.query("What's the population?") + + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_digitalocean_spaces(): + """DigitalOcean Spaces configuration (S3-compatible). + + DigitalOcean Spaces uses the S3 API with a custom endpoint. + This pattern works for any S3-compatible service. + """ + print("=== DigitalOcean Spaces Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + # DigitalOcean Spaces configuration + storage = S3SessionStorage( + S3Config( + bucket="my-space-name", # Your Spaces name + prefix="claude-sessions", + endpoint_url="https://nyc3.digitaloceanspaces.com", # Spaces endpoint + region="nyc3", # Spaces region + # Get these from: https://cloud.digitalocean.com/account/api/tokens + aws_access_key_id=os.getenv("DO_SPACES_KEY", "your-spaces-key"), + aws_secret_access_key=os.getenv("DO_SPACES_SECRET", "your-spaces-secret"), + ) + ) + + options = ClaudeAgentOptions(session_storage=storage) + + async with ClaudeSDKClient(options=options) as client: + print("User: Hello! Remember this: my favorite color is blue.") + await client.query("Hello! 
Remember this: my favorite color is blue.") + + async for msg in client.receive_response(): + display_message(msg) + + print( + "\nNote: Transcript is now stored in DigitalOcean Spaces and can be resumed from any server.\n" + ) + + +async def example_cloudflare_r2(): + """Cloudflare R2 configuration (S3-compatible, no egress fees). + + Cloudflare R2 is fully S3-compatible and has zero egress fees, + making it cost-effective for high-traffic applications. + """ + print("=== Cloudflare R2 Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + # Cloudflare R2 configuration + storage = S3SessionStorage( + S3Config( + bucket="my-r2-bucket", + prefix="claude-sessions", + # R2 endpoint format: https://.r2.cloudflarestorage.com + endpoint_url=os.getenv( + "R2_ENDPOINT", "https://abc123.r2.cloudflarestorage.com" + ), + # Get credentials from Cloudflare dashboard > R2 > Manage R2 API Tokens + aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID", "your-r2-access-key"), + aws_secret_access_key=os.getenv( + "R2_SECRET_ACCESS_KEY", "your-r2-secret-key" + ), + # R2 doesn't require region, but some S3 clients expect it + region="auto", + ) + ) + + options = ClaudeAgentOptions(session_storage=storage) + + async with ClaudeSDKClient(options=options) as client: + print("User: What's 2 + 2?") + await client.query("What's 2 + 2?") + + async for msg in client.receive_response(): + display_message(msg) + + print("\nNote: R2 has zero egress fees - cost-effective for production.\n") + + +async def example_minio(): + """MinIO configuration (self-hosted S3-compatible storage). + + MinIO is an open-source S3-compatible object storage server + that you can self-host on-premise or in your own cloud. + """ + print("=== MinIO Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + # MinIO configuration (local or self-hosted) + _storage = S3SessionStorage( + S3Config( + bucket="claude-sessions", + prefix="sessions", + endpoint_url="http://localhost:9000", # MinIO server endpoint + aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"), + aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"), + # MinIO doesn't require region + ) + ) + + # In a real application: + # options = ClaudeAgentOptions(session_storage=storage) + + print("Note: This example assumes MinIO is running on localhost:9000") + print("To start MinIO: docker run -p 9000:9000 minio/minio server /data\n") + + # In a real application, you would use the client normally + print("Storage configured with MinIO - ready for use.\n") + + +async def example_gcs_basic(): + """Basic Google Cloud Storage usage. + + GCS uses Application Default Credentials (ADC) by default, which works + seamlessly in GCP environments (Compute Engine, GKE, Cloud Run, etc.). 
+ """ + print("=== GCS Basic Example ===\n") + + from claude_agent_sdk.session_storage import GCSConfig, GCSSessionStorage + + # Configure GCS storage + storage = GCSSessionStorage( + GCSConfig( + bucket="my-claude-sessions", # Your GCS bucket name + prefix="claude-sessions", + project="my-gcp-project", # Optional: GCP project ID + # credentials_path="/path/to/service-account.json", # Optional: explicit credentials + ) + ) + + options = ClaudeAgentOptions(session_storage=storage) + + async with ClaudeSDKClient(options=options) as client: + print("User: What is 10 * 7?") + await client.query("What is 10 * 7?") + + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_gcs_with_credentials(): + """GCS with explicit service account credentials. + + Use this when you need to specify credentials explicitly, + for example in local development or CI/CD environments. + """ + print("=== GCS with Service Account Example ===\n") + + from claude_agent_sdk.session_storage import GCSConfig, GCSSessionStorage + + # Configure with service account JSON file + _storage = GCSSessionStorage( + GCSConfig( + bucket="my-claude-sessions", + prefix="claude-prod", + project="my-gcp-project", + # Path to service account JSON key file + credentials_path=os.getenv( + "GOOGLE_APPLICATION_CREDENTIALS", "/path/to/service-account.json" + ), + ) + ) + + # In a real application: + # options = ClaudeAgentOptions(session_storage=storage) + + print("Storage configured with GCS service account credentials.") + print( + "Note: Get credentials from: https://console.cloud.google.com/iam-admin/serviceaccounts\n" + ) + + +async def example_session_resume(): + """Resume a session from cloud storage. + + When you provide a session_id that already exists in cloud storage, + the SDK automatically downloads the transcript and resumes the conversation. + """ + print("=== Session Resume Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + storage = S3SessionStorage( + S3Config( + bucket="my-claude-sessions", + prefix="claude-sessions", + region="us-east-1", + ) + ) + + # First conversation - create a session + session_id = "user-123-conversation" + + print("--- Starting new session ---") + options = ClaudeAgentOptions( + session_storage=storage, + session_id=session_id, # Specify session ID + ) + + async with ClaudeSDKClient(options=options) as client: + print("User: My name is Alice and I love Python.") + await client.query("My name is Alice and I love Python.") + + async for msg in client.receive_response(): + display_message(msg) + + print("\n--- Resuming session on different server ---") + + # Resume the same session (could be on a different server/container) + options_resume = ClaudeAgentOptions( + session_storage=storage, + session_id=session_id, # Same session ID - will download from cloud + ) + + async with ClaudeSDKClient(options=options_resume) as client: + print("User: What's my name? What language do I love?") + await client.query("What's my name? What language do I love?") + + async for msg in client.receive_response(): + display_message(msg) + + print( + "\nNote: Claude remembers the conversation because it was resumed from cloud storage.\n" + ) + + +async def example_list_sessions(): + """List sessions stored in cloud storage. + + Useful for admin interfaces, debugging, or cleanup operations. 
+ """ + print("=== List Sessions Example ===\n") + + from claude_agent_sdk.session_storage import S3Config, S3SessionStorage + + storage = S3SessionStorage( + S3Config( + bucket="my-claude-sessions", + prefix="claude-sessions", + region="us-east-1", + ) + ) + + # List all sessions + sessions = await storage.list_sessions(limit=10) + + print(f"Found {len(sessions)} sessions:\n") + for session_meta in sessions: + print(f"Session ID: {session_meta.session_id}") + print(f" Size: {session_meta.size_bytes:,} bytes") + print(f" Updated: {session_meta.updated_at}") + print(f" Storage key: {session_meta.storage_key}") + print() + + # List sessions with prefix filter + user_sessions = await storage.list_sessions(prefix="user-123-", limit=5) + print(f"Found {len(user_sessions)} sessions for user-123") + + +async def example_production_tips(): + """Production deployment tips. + + This example demonstrates best practices for production use. + """ + print("=== Production Tips ===\n") + + print("Production configuration example:") + print() + print("from claude_agent_sdk.session_storage import S3SessionStorage, S3Config") + print() + print("storage = S3SessionStorage(") + print(" S3Config(") + print(" bucket='prod-claude-sessions',") + print(" prefix='claude/v1', # Version your storage structure") + print(" region='us-east-1',") + print(" ),") + print(" max_retries=3, # Retry failed operations") + print(" retry_delay=1.0, # Base delay between retries (exponential backoff)") + print(")") + print() + print("Production best practices:") + print() + print("1. LATENCY WARNING:") + print(" - S3/GCS operations add 50-500ms+ per operation") + print(" - For high-throughput, use caching (see session_storage_cached.py)") + print() + print("2. CREDENTIALS:") + print(" - Use IAM roles in AWS (no hardcoded credentials)") + print(" - Use workload identity in GCP") + print(" - Use environment variables for keys") + print() + print("3. BUCKET CONFIGURATION:") + print(" - Enable versioning for data safety") + print(" - Set lifecycle policies to archive/delete old sessions") + print(" - Configure CORS if accessing from browser") + print() + print("4. MONITORING:") + print(" - Track upload/download latencies") + print(" - Monitor storage costs") + print(" - Alert on error rates") + print() + print("5. 
SESSION IDs:") + print(" - Use meaningful IDs: user-{user_id}-{timestamp}") + print(" - Include tenant/org ID for multi-tenant apps") + print(" - Avoid PII in session IDs (stored in S3 key)") + print() + + +async def main(): + """Run all examples with error handling.""" + examples = { + "s3_basic": ( + example_s3_basic, + "Basic S3 usage with AWS", + ), + "digitalocean": ( + example_digitalocean_spaces, + "DigitalOcean Spaces (S3-compatible)", + ), + "cloudflare_r2": ( + example_cloudflare_r2, + "Cloudflare R2 (S3-compatible, zero egress)", + ), + "minio": ( + example_minio, + "MinIO (self-hosted S3-compatible)", + ), + "gcs_basic": ( + example_gcs_basic, + "Basic GCS usage", + ), + "gcs_credentials": ( + example_gcs_with_credentials, + "GCS with service account", + ), + "session_resume": ( + example_session_resume, + "Resume session from cloud storage", + ), + "list_sessions": ( + example_list_sessions, + "List and inspect stored sessions", + ), + "production": ( + example_production_tips, + "Production deployment best practices", + ), + } + + print("Claude Agent SDK - Session Storage Examples") + print("=" * 50) + print() + print("Available examples:") + for name, (_, description) in examples.items(): + print(f" {name:20} - {description}") + print() + print( + "Note: These examples demonstrate the API without requiring actual cloud credentials." + ) + print( + " In real usage, ensure credentials are configured via environment variables or IAM roles." + ) + print() + print("=" * 50) + print() + + # Run non-network examples + await example_production_tips() + + +if __name__ == "__main__": + # Set up basic logging to see what's happening + import logging + + logging.basicConfig( + level=logging.INFO, format="%(levelname)s - %(name)s - %(message)s" + ) + + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index bfce3066..a8b003d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,18 @@ dev = [ "mypy>=1.0.0", "ruff>=0.1.0", ] +s3 = [ + "aiobotocore>=2.5.0", + "types-aiobotocore[s3]>=2.5.0", +] +gcs = [ + "google-cloud-storage>=2.0.0", +] +cloud = [ + "aiobotocore>=2.5.0", + "types-aiobotocore[s3]>=2.5.0", + "google-cloud-storage>=2.0.0", +] [project.urls] Homepage = "https://github.com/anthropics/claude-agent-sdk-python" diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 4898bc0b..c3e43a2c 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -10,6 +10,7 @@ CLIJSONDecodeError, CLINotFoundError, ProcessError, + SessionStorageError, ) from ._internal.transport import Transport from ._version import __version__ @@ -362,4 +363,5 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "CLINotFoundError", "ProcessError", "CLIJSONDecodeError", + "SessionStorageError", ] diff --git a/src/claude_agent_sdk/_errors.py b/src/claude_agent_sdk/_errors.py index c86bf235..f8da01ab 100644 --- a/src/claude_agent_sdk/_errors.py +++ b/src/claude_agent_sdk/_errors.py @@ -54,3 +54,45 @@ class MessageParseError(ClaudeSDKError): def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) + + +class SessionStorageError(ClaudeSDKError): + """Raised when session storage operations fail. + + This error is raised for cloud storage failures such as upload/download + errors, permission issues, or network problems. + + Attributes: + session_id: The session ID involved in the failed operation. + operation: The operation that failed (upload, download, delete, etc.). 
+ original_error: The underlying exception that caused this error. + + Example: + >>> try: + ... await storage.upload_transcript("session-123", "/tmp/transcript.jsonl") + ... except SessionStorageError as e: + ... print(f"Failed to upload session {e.session_id}: {e}") + ... if e.original_error: + ... print(f"Caused by: {e.original_error}") + """ + + def __init__( + self, + message: str, + session_id: str | None = None, + operation: str | None = None, + original_error: Exception | None = None, + ): + self.session_id = session_id + self.operation = operation + self.original_error = original_error + + parts = [message] + if session_id: + parts.append(f"session: {session_id}") + if operation: + parts.append(f"operation: {operation}") + if original_error: + parts.append(f"caused by: {original_error}") + + super().__init__(" | ".join(parts)) diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index 18ab818d..899c8922 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -10,6 +10,19 @@ from ._errors import CLIConnectionError from .types import ClaudeAgentOptions, HookEvent, HookMatcher, Message, ResultMessage +# Lazy import to avoid circular dependency +_SessionSyncManager = None + + +def _get_sync_manager_class() -> type: + """Lazy load SessionSyncManager to avoid circular imports.""" + global _SessionSyncManager + if _SessionSyncManager is None: + from .session_storage._sync import SessionSyncManager + + _SessionSyncManager = SessionSyncManager + return _SessionSyncManager + class ClaudeSDKClient: """ @@ -64,6 +77,7 @@ def __init__( self._custom_transport = transport self._transport: Transport | None = None self._query: Any | None = None + self._sync_manager: Any | None = None os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" def _convert_hooks_to_internal_format( @@ -123,6 +137,25 @@ async def _empty_stream() -> AsyncIterator[dict[str, Any]]: else: options = self.options + # Set up session storage if configured + if options.session_storage is not None: + sync_manager_cls = _get_sync_manager_class() + self._sync_manager = sync_manager_cls( + storage=options.session_storage, + transcript_dir=options.transcript_dir, + ) + + # If resuming a session, download from cloud storage first + if options.resume: + await self._sync_manager.prepare_session(options.resume) + + # Add Stop hook for automatic upload + stop_hook = self._sync_manager.create_stop_hook() + existing_hooks = options.hooks or {} + stop_matchers = list(existing_hooks.get("Stop", [])) + stop_matchers.append(HookMatcher(matcher=None, hooks=[stop_hook])) + options = replace(options, hooks={**existing_hooks, "Stop": stop_matchers}) + # Use provided custom transport or create subprocess transport if self._custom_transport: self._transport = self._custom_transport diff --git a/src/claude_agent_sdk/session_storage/__init__.py b/src/claude_agent_sdk/session_storage/__init__.py new file mode 100644 index 00000000..834f872a --- /dev/null +++ b/src/claude_agent_sdk/session_storage/__init__.py @@ -0,0 +1,86 @@ +"""Session storage backends for cloud persistence. + +This module provides abstractions and implementations for storing session +transcripts in cloud storage, enabling: + +- Horizontal scaling across multiple servers with shared sessions +- Support for ephemeral filesystems (containers, serverless) +- Extensible architecture for custom backends + +WARNING: Direct cloud storage operations add latency (50-500ms+ per operation). 
+For production at scale, consider wrapping implementations with a caching layer. + +Example: + Basic usage with S3: + + >>> from claude_agent_sdk import ClaudeAgentOptions + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> + >>> storage = S3SessionStorage(S3Config( + ... bucket="my-sessions", + ... prefix="claude", + ... region="us-east-1", + ... )) + >>> + >>> options = ClaudeAgentOptions(session_storage=storage) + + Custom caching wrapper (for production scale): + + >>> class CachedStorage: + ... def __init__(self, backend, cache): + ... self.backend = backend + ... self.cache = cache + ... + ... async def download_transcript(self, session_id, local_path): + ... if await self.cache.has(session_id): + ... return await self.cache.get(session_id, local_path) + ... return await self.backend.download_transcript(session_id, local_path) + +Available backends: + - S3SessionStorage: AWS S3 (requires: pip install claude-agent-sdk[s3]) + - GCSSessionStorage: Google Cloud Storage (requires: pip install claude-agent-sdk[gcs]) + +See Also: + - SessionStorage: Protocol for implementing custom backends + - BaseSessionStorage: Base class with retry logic +""" + +from ._base import BaseSessionStorage +from ._protocol import SessionMetadata, SessionStorage + +__all__ = [ + # Protocol and base class + "SessionStorage", + "SessionMetadata", + "BaseSessionStorage", +] + + +# Lazy imports for optional cloud dependencies +def __getattr__(name: str) -> object: + """Lazy import cloud storage implementations. + + This allows importing the main module without requiring boto3 or + google-cloud-storage to be installed. + """ + if name == "S3SessionStorage": + from ._s3 import S3SessionStorage + + return S3SessionStorage + if name == "S3Config": + from ._s3 import S3Config + + return S3Config + if name == "GCSSessionStorage": + from ._gcs import GCSSessionStorage + + return GCSSessionStorage + if name == "GCSConfig": + from ._gcs import GCSConfig + + return GCSConfig + if name == "SessionSyncManager": + from ._sync import SessionSyncManager + + return SessionSyncManager + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/claude_agent_sdk/session_storage/_base.py b/src/claude_agent_sdk/session_storage/_base.py new file mode 100644 index 00000000..62ad74ad --- /dev/null +++ b/src/claude_agent_sdk/session_storage/_base.py @@ -0,0 +1,333 @@ +"""Base class for session storage implementations.""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from pathlib import Path + +import anyio + +from ._protocol import SessionMetadata + +logger = logging.getLogger(__name__) + + +class BaseSessionStorage(ABC): + """Abstract base class for session storage with common functionality. + + Provides shared logic like path normalization, key generation, and retry handling. + Subclass this for concrete implementations (S3, GCS, etc.). + + WARNING: Direct cloud storage operations add latency (50-500ms+ per operation). + For production at scale, consider wrapping with a caching layer. + + Attributes: + prefix: Storage key prefix for organizing sessions. + max_retries: Maximum retry attempts for failed operations. + retry_delay: Base delay between retries (exponential backoff applied). + + Example: + Implementing a custom backend: + + >>> class MyCloudStorage(BaseSessionStorage): + ... async def _do_upload(self, key: str, local_path: Path) -> None: + ... # Your upload logic here + ... pass + ... + ... 
async def _do_download(self, key: str, local_path: Path) -> bool: + ... # Your download logic here + ... return True + ... + ... # ... implement other abstract methods + """ + + def __init__( + self, + prefix: str = "claude-sessions", + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> None: + """Initialize base session storage. + + Args: + prefix: Storage key prefix for organizing sessions. + max_retries: Maximum retry attempts for failed operations. + retry_delay: Base delay in seconds between retries. + """ + self.prefix = prefix.rstrip("/") + self.max_retries = max_retries + self.retry_delay = retry_delay + + def _get_key(self, session_id: str) -> str: + """Generate storage key for a session. + + Sanitizes session_id to prevent path traversal attacks. + + Args: + session_id: The session identifier. + + Returns: + Full storage key including prefix. + """ + # Sanitize session_id to prevent path traversal + safe_id = session_id.replace("/", "_").replace("\\", "_").replace("..", "_") + return f"{self.prefix}/{safe_id}/transcript.jsonl" + + def _extract_session_id(self, key: str) -> str | None: + """Extract session_id from a storage key. + + Args: + key: Full storage key. + + Returns: + Session ID or None if key doesn't match expected format. + """ + if not key.startswith(self.prefix + "/"): + return None + remainder = key[len(self.prefix) + 1 :] + parts = remainder.split("/") + if len(parts) >= 1: + return parts[0] + return None + + @abstractmethod + async def _do_upload(self, key: str, local_path: Path) -> None: + """Backend-specific upload implementation. + + Args: + key: Storage key to upload to. + local_path: Local file to upload. + + Raises: + Exception: On upload failure. + """ + ... + + @abstractmethod + async def _do_download(self, key: str, local_path: Path) -> bool: + """Backend-specific download implementation. + + Args: + key: Storage key to download from. + local_path: Local path to save file. + + Returns: + True if downloaded, False if not found. + + Raises: + Exception: On download failure (other than not found). + """ + ... + + @abstractmethod + async def _do_exists(self, key: str) -> bool: + """Backend-specific existence check. + + Args: + key: Storage key to check. + + Returns: + True if exists. + """ + ... + + @abstractmethod + async def _do_delete(self, key: str) -> bool: + """Backend-specific delete implementation. + + Args: + key: Storage key to delete. + + Returns: + True if deleted, False if not found. + """ + ... + + @abstractmethod + async def _do_list(self, prefix: str, limit: int) -> list[SessionMetadata]: + """Backend-specific list implementation. + + Args: + prefix: Full prefix to list under. + limit: Maximum items to return. + + Returns: + List of session metadata. + """ + ... + + @abstractmethod + async def _do_get_metadata(self, key: str) -> SessionMetadata | None: + """Backend-specific metadata retrieval. + + Args: + key: Storage key to get metadata for. + + Returns: + Metadata or None if not found. + """ + ... + + async def upload_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> str: + """Upload a local transcript file to cloud storage. + + Includes retry logic with exponential backoff. + + Args: + session_id: The session identifier. + local_path: Path to the local transcript file. + + Returns: + The cloud storage key for the uploaded file. + + Raises: + SessionStorageError: If upload fails after all retries. 
+ """ + from .._errors import SessionStorageError + + key = self._get_key(session_id) + path = Path(local_path) + + if not path.exists(): + raise SessionStorageError( + f"Local transcript not found: {path}", + session_id=session_id, + operation="upload", + ) + + last_error: Exception | None = None + for attempt in range(self.max_retries): + try: + await self._do_upload(key, path) + logger.debug(f"Uploaded session {session_id} to {key}") + return key + except Exception as e: + last_error = e + if attempt < self.max_retries - 1: + delay = self.retry_delay * (2**attempt) + logger.warning( + f"Upload attempt {attempt + 1} failed for {session_id}, " + f"retrying in {delay}s: {e}" + ) + await anyio.sleep(delay) + + raise SessionStorageError( + f"Upload failed after {self.max_retries} attempts", + session_id=session_id, + operation="upload", + original_error=last_error, + ) + + async def download_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> bool: + """Download a transcript from cloud storage to local path. + + Includes retry logic with exponential backoff. + + Args: + session_id: The session identifier. + local_path: Where to save the downloaded file. + + Returns: + True if download succeeded, False if session not found. + + Raises: + SessionStorageError: If download fails after all retries. + """ + from .._errors import SessionStorageError + + key = self._get_key(session_id) + path = Path(local_path) + + # Ensure parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + last_error: Exception | None = None + for attempt in range(self.max_retries): + try: + result = await self._do_download(key, path) + if result: + logger.debug(f"Downloaded session {session_id} to {path}") + else: + logger.debug(f"Session {session_id} not found in storage") + return result + except Exception as e: + last_error = e + if attempt < self.max_retries - 1: + delay = self.retry_delay * (2**attempt) + logger.warning( + f"Download attempt {attempt + 1} failed for {session_id}, " + f"retrying in {delay}s: {e}" + ) + await anyio.sleep(delay) + + raise SessionStorageError( + f"Download failed after {self.max_retries} attempts", + session_id=session_id, + operation="download", + original_error=last_error, + ) + + async def exists(self, session_id: str) -> bool: + """Check if a session exists in cloud storage. + + Args: + session_id: The session identifier. + + Returns: + True if session exists in storage. + """ + key = self._get_key(session_id) + return await self._do_exists(key) + + async def delete(self, session_id: str) -> bool: + """Delete a session from cloud storage. + + Args: + session_id: The session identifier. + + Returns: + True if deleted, False if not found. + """ + key = self._get_key(session_id) + result = await self._do_delete(key) + if result: + logger.debug(f"Deleted session {session_id}") + return result + + async def list_sessions( + self, + prefix: str | None = None, + limit: int = 100, + ) -> list[SessionMetadata]: + """List sessions in cloud storage. + + Args: + prefix: Optional prefix filter for session IDs. + limit: Maximum number of sessions to return. + + Returns: + List of session metadata. + """ + full_prefix = f"{self.prefix}/{prefix}" if prefix else self.prefix + return await self._do_list(full_prefix, limit) + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata for a session. + + Args: + session_id: The session identifier. + + Returns: + Session metadata or None if not found. 
+ """ + key = self._get_key(session_id) + return await self._do_get_metadata(key) diff --git a/src/claude_agent_sdk/session_storage/_gcs.py b/src/claude_agent_sdk/session_storage/_gcs.py new file mode 100644 index 00000000..ba994e6e --- /dev/null +++ b/src/claude_agent_sdk/session_storage/_gcs.py @@ -0,0 +1,365 @@ +"""Google Cloud Storage session storage implementation.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import anyio + +from ._base import BaseSessionStorage +from ._protocol import SessionMetadata + +if TYPE_CHECKING: + from google.cloud.storage import Bucket, Client # type: ignore[import-not-found] + +logger = logging.getLogger(__name__) + + +@dataclass +class GCSConfig: + """Configuration for Google Cloud Storage session storage. + + Attributes: + bucket: GCS bucket name (required). + prefix: Key prefix for organizing sessions. Defaults to "claude-sessions". + project: GCP project ID. If None, uses default from credentials. + credentials_path: Path to service account JSON key file. If None, uses + Application Default Credentials (ADC). + + Example: + Using service account credentials: + + >>> config = GCSConfig( + ... bucket="my-sessions", + ... prefix="claude-prod", + ... project="my-gcp-project", + ... credentials_path="/path/to/service-account.json" + ... ) + + Using Application Default Credentials: + + >>> config = GCSConfig(bucket="my-sessions") + """ + + bucket: str + prefix: str = "claude-sessions" + project: str | None = None + credentials_path: str | None = None + + +class GCSSessionStorage(BaseSessionStorage): + """Google Cloud Storage session storage implementation. + + WARNING: GCS operations add latency (50-500ms+ per operation). For production + at scale, consider wrapping with a caching layer (Redis, local LRU, etc.). + + This implementation uses the google-cloud-storage library and wraps synchronous + operations with anyio.to_thread.run_sync() for async compatibility. + + Authentication: + The client authenticates using: + 1. Service account JSON file if credentials_path is provided + 2. Application Default Credentials (ADC) otherwise: + - GOOGLE_APPLICATION_CREDENTIALS environment variable + - gcloud CLI credentials + - Compute Engine/GKE/Cloud Run service account + + Installation: + This implementation requires google-cloud-storage: + + >>> pip install claude-agent-sdk[gcs] + + Or install directly: + + >>> pip install google-cloud-storage + + Example: + Basic usage: + + >>> from claude_agent_sdk.session_storage import GCSSessionStorage, GCSConfig + >>> storage = GCSSessionStorage(GCSConfig( + ... bucket="my-sessions", + ... prefix="claude", + ... project="my-gcp-project" + ... )) + >>> # Upload a transcript + >>> key = await storage.upload_transcript( + ... "session-123", + ... "/tmp/transcript.jsonl" + ... ) + >>> print(f"Uploaded to: {key}") + 'claude/session-123/transcript.jsonl' + >>> + >>> # Download a transcript + >>> success = await storage.download_transcript( + ... "session-123", + ... "/tmp/restored.jsonl" + ... ) + >>> + >>> # List sessions + >>> sessions = await storage.list_sessions(prefix="prod-", limit=50) + >>> for meta in sessions: + ... print(f"{meta.session_id}: {meta.size_bytes} bytes") + + With service account credentials: + + >>> storage = GCSSessionStorage(GCSConfig( + ... bucket="my-sessions", + ... credentials_path="/path/to/service-account.json" + ... )) + + Attributes: + config: GCS configuration. 
+ max_retries: Maximum retry attempts (inherited from BaseSessionStorage). + retry_delay: Base delay between retries (inherited from BaseSessionStorage). + """ + + def __init__( + self, + config: GCSConfig, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> None: + """Initialize GCS session storage. + + Args: + config: GCS configuration. + max_retries: Maximum retry attempts for failed operations. + retry_delay: Base delay in seconds between retries. + + Raises: + ImportError: If google-cloud-storage is not installed. + """ + super().__init__( + prefix=config.prefix, + max_retries=max_retries, + retry_delay=retry_delay, + ) + self.config = config + self._client: Client | None = None + self._bucket: Bucket | None = None + + # Validate import on init + self._ensure_gcs_available() + + def _ensure_gcs_available(self) -> None: + """Ensure google-cloud-storage is installed. + + Raises: + ImportError: If google-cloud-storage is not installed. + """ + try: + import google.cloud.storage # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: F401 + except ImportError as e: + raise ImportError( + "google-cloud-storage is required for GCSSessionStorage. " + "Install it with: pip install claude-agent-sdk[gcs] " + "or: pip install google-cloud-storage" + ) from e + + def _get_client(self) -> Client: + """Get or create the GCS client (lazy initialization). + + Returns: + Initialized GCS client. + """ + if self._client is None: + # ruff: noqa: I001 + from google.cloud import storage # type: ignore[import-not-found,import-untyped,unused-ignore] + from google.oauth2 import service_account # type: ignore[import-not-found,import-untyped,unused-ignore] + + if self.config.credentials_path: + # Use service account credentials from file + credentials = service_account.Credentials.from_service_account_file( + self.config.credentials_path + ) + self._client = storage.Client( + project=self.config.project, + credentials=credentials, + ) + else: + # Use Application Default Credentials + self._client = storage.Client(project=self.config.project) + + return self._client + + def _get_bucket(self) -> Bucket: + """Get or create the GCS bucket reference (lazy initialization). + + Returns: + GCS bucket object. + """ + if self._bucket is None: + client = self._get_client() + self._bucket = client.bucket(self.config.bucket) + return self._bucket + + async def _do_upload(self, key: str, local_path: Path) -> None: + """Upload a file to GCS. + + Args: + key: GCS object key. + local_path: Local file to upload. + + Raises: + Exception: On upload failure. + """ + + def _sync_upload() -> None: + bucket = self._get_bucket() + blob = bucket.blob(key) + blob.upload_from_filename(str(local_path)) + + await anyio.to_thread.run_sync(_sync_upload) + + async def _do_download(self, key: str, local_path: Path) -> bool: + """Download a file from GCS. + + Args: + key: GCS object key. + local_path: Local path to save file. + + Returns: + True if downloaded, False if not found. + + Raises: + Exception: On download failure (other than not found). + """ + + def _sync_download() -> bool: + from google.api_core.exceptions import NotFound # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: I001 + + bucket = self._get_bucket() + blob = bucket.blob(key) + + try: + blob.download_to_filename(str(local_path)) + return True + except NotFound: + return False + + return await anyio.to_thread.run_sync(_sync_download) + + async def _do_exists(self, key: str) -> bool: + """Check if an object exists in GCS. 
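+
+        Runs the synchronous ``blob.exists()`` call in a worker thread via
+        ``anyio.to_thread.run_sync()``.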
+ + Args: + key: GCS object key. + + Returns: + True if exists. + """ + + def _sync_exists() -> bool: + bucket = self._get_bucket() + blob = bucket.blob(key) + return bool(blob.exists()) + + return await anyio.to_thread.run_sync(_sync_exists) + + async def _do_delete(self, key: str) -> bool: + """Delete an object from GCS. + + Args: + key: GCS object key. + + Returns: + True if deleted, False if not found. + """ + + def _sync_delete() -> bool: + from google.api_core.exceptions import NotFound # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: I001 + + bucket = self._get_bucket() + blob = bucket.blob(key) + + try: + blob.delete() + return True + except NotFound: + return False + + return await anyio.to_thread.run_sync(_sync_delete) + + async def _do_list(self, prefix: str, limit: int) -> list[SessionMetadata]: + """List objects in GCS with given prefix. + + Args: + prefix: GCS key prefix. + limit: Maximum number of items to return. + + Returns: + List of session metadata. + """ + + def _sync_list() -> list[SessionMetadata]: + bucket = self._get_bucket() + blobs = bucket.list_blobs(prefix=prefix, max_results=limit) + + results: list[SessionMetadata] = [] + for blob in blobs: + # Extract session_id from key + session_id = self._extract_session_id(blob.name) + if not session_id: + continue + + # Convert timestamps + created_at = blob.time_created.timestamp() if blob.time_created else 0.0 + updated_at = blob.updated.timestamp() if blob.updated else created_at + + metadata = SessionMetadata( + session_id=session_id, + created_at=created_at, + updated_at=updated_at, + size_bytes=blob.size or 0, + storage_key=blob.name, + ) + results.append(metadata) + + return results + + return await anyio.to_thread.run_sync(_sync_list) + + async def _do_get_metadata(self, key: str) -> SessionMetadata | None: + """Get metadata for a specific object in GCS. + + Args: + key: GCS object key. + + Returns: + Session metadata or None if not found. + """ + + def _sync_get_metadata() -> SessionMetadata | None: + from google.api_core.exceptions import NotFound # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: I001 + + bucket = self._get_bucket() + blob = bucket.blob(key) + + try: + # Reload to fetch metadata + blob.reload() + except NotFound: + return None + + # Extract session_id from key + session_id = self._extract_session_id(blob.name) + if not session_id: + return None + + # Convert timestamps + created_at = blob.time_created.timestamp() if blob.time_created else 0.0 + updated_at = blob.updated.timestamp() if blob.updated else created_at + + return SessionMetadata( + session_id=session_id, + created_at=created_at, + updated_at=updated_at, + size_bytes=blob.size or 0, + storage_key=blob.name, + ) + + return await anyio.to_thread.run_sync(_sync_get_metadata) diff --git a/src/claude_agent_sdk/session_storage/_protocol.py b/src/claude_agent_sdk/session_storage/_protocol.py new file mode 100644 index 00000000..1fcd2823 --- /dev/null +++ b/src/claude_agent_sdk/session_storage/_protocol.py @@ -0,0 +1,151 @@ +"""Session storage protocol definition.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol, runtime_checkable + + +@dataclass +class SessionMetadata: + """Metadata about a stored session. + + Attributes: + session_id: Unique identifier for the session. + created_at: Unix timestamp when session was first stored. + updated_at: Unix timestamp when session was last updated. 
+ size_bytes: Size of the transcript in bytes. + storage_key: Backend-specific storage key/path. + """ + + session_id: str + created_at: float + updated_at: float + size_bytes: int + storage_key: str + + +@runtime_checkable +class SessionStorage(Protocol): + """Protocol for session storage backends. + + Implementations provide async methods for uploading, downloading, + and managing session transcripts in cloud storage. + + The Claude Code CLI writes transcripts to local paths. This protocol + abstracts syncing those local files to/from cloud storage for persistence + across ephemeral environments and horizontal scaling. + + WARNING: Direct cloud storage operations add latency (50-500ms+ per operation). + For production at scale, consider wrapping implementations with a caching layer + (Redis, local LRU cache, etc.). + + Example: + Basic usage with S3: + + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> storage = S3SessionStorage(S3Config(bucket="my-bucket")) + >>> await storage.upload_transcript("session-123", "/tmp/transcript.jsonl") + 'claude-sessions/session-123/transcript.jsonl' + + Custom implementation: + + >>> class MyStorage: + ... async def upload_transcript(self, session_id, local_path): + ... # Upload to your backend + ... return f"my-backend/{session_id}" + ... + ... async def download_transcript(self, session_id, local_path): + ... # Download from your backend + ... return True + ... + ... # ... implement other methods + """ + + async def upload_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> str: + """Upload a local transcript file to cloud storage. + + Args: + session_id: The session identifier. + local_path: Path to the local transcript file. + + Returns: + The cloud storage key/URL for the uploaded file. + + Raises: + SessionStorageError: If upload fails. + """ + ... + + async def download_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> bool: + """Download a transcript from cloud storage to local path. + + Args: + session_id: The session identifier. + local_path: Where to save the downloaded file. + + Returns: + True if download succeeded, False if session not found. + + Raises: + SessionStorageError: If download fails for reasons other than not found. + """ + ... + + async def exists(self, session_id: str) -> bool: + """Check if a session exists in cloud storage. + + Args: + session_id: The session identifier. + + Returns: + True if session exists in storage. + """ + ... + + async def delete(self, session_id: str) -> bool: + """Delete a session from cloud storage. + + Args: + session_id: The session identifier. + + Returns: + True if deleted, False if not found. + """ + ... + + async def list_sessions( + self, + prefix: str | None = None, + limit: int = 100, + ) -> list[SessionMetadata]: + """List sessions in cloud storage. + + Args: + prefix: Optional prefix filter for session IDs. + limit: Maximum number of sessions to return. + + Returns: + List of session metadata. + """ + ... + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata for a session. + + Args: + session_id: The session identifier. + + Returns: + Session metadata or None if not found. + """ + ... diff --git a/src/claude_agent_sdk/session_storage/_s3.py b/src/claude_agent_sdk/session_storage/_s3.py new file mode 100644 index 00000000..9534f57d --- /dev/null +++ b/src/claude_agent_sdk/session_storage/_s3.py @@ -0,0 +1,414 @@ +"""S3 session storage implementation. 
+ +This module provides an S3-backed session storage implementation using +aiobotocore for async operations. + +WARNING: S3 operations add latency (50-500ms per operation). For production +at scale, consider wrapping with a caching layer (Redis, local LRU, etc.). + +Installation: + pip install claude-agent-sdk[s3] + + Or install aiobotocore directly: + pip install aiobotocore + +Example: + Basic AWS S3 usage: + + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> + >>> storage = S3SessionStorage(S3Config( + ... bucket="my-sessions", + ... prefix="claude-sessions", + ... region="us-east-1", + ... )) + >>> await storage.upload_transcript("session-123", "/tmp/transcript.jsonl") + 'claude-sessions/session-123/transcript.jsonl' + + With explicit credentials: + + >>> storage = S3SessionStorage(S3Config( + ... bucket="my-sessions", + ... aws_access_key_id="AKIAIOSFODNN7EXAMPLE", + ... aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ... region="us-east-1", + ... )) + + S3-compatible services (MinIO, Cloudflare R2, DigitalOcean Spaces): + + >>> # MinIO + >>> storage = S3SessionStorage(S3Config( + ... bucket="my-sessions", + ... endpoint_url="https://minio.example.com", + ... aws_access_key_id="minioadmin", + ... aws_secret_access_key="minioadmin", + ... )) + >>> + >>> # Cloudflare R2 + >>> storage = S3SessionStorage(S3Config( + ... bucket="my-sessions", + ... endpoint_url="https://account-id.r2.cloudflarestorage.com", + ... aws_access_key_id="your-r2-access-key", + ... aws_secret_access_key="your-r2-secret-key", + ... )) + +Notes: + - If credentials not provided, uses AWS credential chain: + environment vars, IAM roles, shared credentials file + - The client is lazy-initialized on first use + - Call close() to clean up resources when done + - S3-compatible services work with endpoint_url parameter +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from ._base import BaseSessionStorage +from ._protocol import SessionMetadata + +if TYPE_CHECKING: + from types_aiobotocore_s3 import S3Client # type: ignore[import-not-found] + +logger = logging.getLogger(__name__) + +# Check for aiobotocore availability +try: + import aiobotocore.session # type: ignore[import-not-found] + from botocore.exceptions import ClientError # type: ignore[import-not-found] + + _HAS_AIOBOTOCORE = True +except ImportError: + _HAS_AIOBOTOCORE = False + + +@dataclass +class S3Config: + """Configuration for S3 session storage. + + Attributes: + bucket: S3 bucket name (required). + prefix: Key prefix for organizing sessions (default: "claude-sessions"). + region: AWS region (optional, uses default if not set). + endpoint_url: Custom S3 endpoint for S3-compatible services (optional). + Examples: MinIO, Cloudflare R2, DigitalOcean Spaces. + aws_access_key_id: AWS access key (optional, uses credential chain if not set). + aws_secret_access_key: AWS secret key (optional, uses credential chain if not set). + + Example: + AWS S3: + >>> config = S3Config(bucket="my-bucket", region="us-east-1") + + MinIO: + >>> config = S3Config( + ... bucket="my-bucket", + ... endpoint_url="https://minio.example.com", + ... aws_access_key_id="admin", + ... aws_secret_access_key="password", + ... 
) + """ + + bucket: str + prefix: str = "claude-sessions" + region: str | None = None + endpoint_url: str | None = None + aws_access_key_id: str | None = None + aws_secret_access_key: str | None = None + + +class S3SessionStorage(BaseSessionStorage): + """S3-backed session storage with async operations. + + Uses aiobotocore for efficient async S3 operations. Supports both AWS S3 + and S3-compatible services (MinIO, Cloudflare R2, etc.). + + WARNING: S3 operations add latency (50-500ms per operation). For production + at scale, consider wrapping with a caching layer. + + The S3 client is lazy-initialized on first use to avoid connection overhead + during initialization. + + Attributes: + config: S3 configuration. + max_retries: Maximum retry attempts (inherited from BaseSessionStorage). + retry_delay: Base retry delay in seconds (inherited from BaseSessionStorage). + + Example: + >>> from claude_agent_sdk.session_storage import S3SessionStorage, S3Config + >>> + >>> config = S3Config(bucket="my-sessions", region="us-east-1") + >>> storage = S3SessionStorage(config) + >>> + >>> # Upload a session + >>> key = await storage.upload_transcript("session-123", "/tmp/transcript.jsonl") + >>> + >>> # Download a session + >>> success = await storage.download_transcript("session-123", "/tmp/downloaded.jsonl") + >>> + >>> # List sessions + >>> sessions = await storage.list_sessions(limit=10) + >>> + >>> # Clean up + >>> await storage.close() + + Note: + Call close() when done to properly clean up the S3 client connection pool. + """ + + def __init__( + self, + config: S3Config, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> None: + """Initialize S3 session storage. + + Args: + config: S3 configuration. + max_retries: Maximum retry attempts for failed operations. + retry_delay: Base delay in seconds between retries. + + Raises: + ImportError: If aiobotocore is not installed. + """ + if not _HAS_AIOBOTOCORE: + raise ImportError( + "aiobotocore is required for S3 session storage. " + "Install with: pip install claude-agent-sdk[s3]" + ) + + super().__init__( + prefix=config.prefix, max_retries=max_retries, retry_delay=retry_delay + ) + self.config = config + self._client: S3Client | None = None + self._session: aiobotocore.session.AioSession | None = None + self._client_context: object | None = None + + async def _get_client(self) -> S3Client: + """Get or create the S3 client. + + Lazy-initializes the client on first use. + + Returns: + S3 client instance. + """ + if self._client is not None: + return self._client + + # Create session + self._session = aiobotocore.session.get_session() + + # Build client config + client_kwargs = {"service_name": "s3"} + + if self.config.region: + client_kwargs["region_name"] = self.config.region + if self.config.endpoint_url: + client_kwargs["endpoint_url"] = self.config.endpoint_url + if self.config.aws_access_key_id: + client_kwargs["aws_access_key_id"] = self.config.aws_access_key_id + if self.config.aws_secret_access_key: + client_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key + + # Create client context manager + self._client_context = self._session.create_client(**client_kwargs) + self._client = await self._client_context.__aenter__() # type: ignore + + logger.debug(f"Initialized S3 client for bucket: {self.config.bucket}") + return self._client + + async def close(self) -> None: + """Close the S3 client and clean up resources. + + Should be called when done using the storage to properly close + the connection pool. 
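+
+        If the storage is used again after ``close()``, a new client is
+        lazily created on the next operation.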
+ + Example: + >>> storage = S3SessionStorage(config) + >>> try: + ... await storage.upload_transcript("session-123", "/tmp/transcript.jsonl") + ... finally: + ... await storage.close() + """ + if self._client_context is not None and self._client is not None: + await self._client_context.__aexit__(None, None, None) # type: ignore + self._client = None + self._client_context = None + logger.debug("Closed S3 client") + + async def _do_upload(self, key: str, local_path: Path) -> None: + """Upload a file to S3. + + Args: + key: S3 key to upload to. + local_path: Local file to upload. + + Raises: + Exception: On upload failure. + """ + client = await self._get_client() + + with local_path.open("rb") as f: + await client.put_object( + Bucket=self.config.bucket, + Key=key, + Body=f, + ) + + async def _do_download(self, key: str, local_path: Path) -> bool: + """Download a file from S3. + + Args: + key: S3 key to download from. + local_path: Local path to save file. + + Returns: + True if downloaded, False if not found. + + Raises: + Exception: On download failure (other than not found). + """ + client = await self._get_client() + + try: + response = await client.get_object( + Bucket=self.config.bucket, + Key=key, + ) + + async with response["Body"] as stream: + data = await stream.read() + + local_path.write_bytes(data) + return True + + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return False + raise + + async def _do_exists(self, key: str) -> bool: + """Check if a key exists in S3. + + Args: + key: S3 key to check. + + Returns: + True if exists. + """ + client = await self._get_client() + + try: + await client.head_object( + Bucket=self.config.bucket, + Key=key, + ) + return True + except ClientError as e: + if e.response["Error"]["Code"] in ("404", "NoSuchKey"): + return False + raise + + async def _do_delete(self, key: str) -> bool: + """Delete a key from S3. + + Args: + key: S3 key to delete. + + Returns: + True if deleted, False if not found. + """ + client = await self._get_client() + + # Check if exists first + exists = await self._do_exists(key) + if not exists: + return False + + await client.delete_object( + Bucket=self.config.bucket, + Key=key, + ) + return True + + async def _do_list(self, prefix: str, limit: int) -> list[SessionMetadata]: + """List objects with given prefix. + + Args: + prefix: S3 key prefix to list under. + limit: Maximum items to return. + + Returns: + List of session metadata. + """ + client = await self._get_client() + + # Ensure prefix ends with / for directory-like listing + if not prefix.endswith("/"): + prefix = prefix + "/" + + response = await client.list_objects_v2( + Bucket=self.config.bucket, + Prefix=prefix, + MaxKeys=limit, + ) + + results: list[SessionMetadata] = [] + contents = response.get("Contents", []) + + for obj in contents: + key = obj["Key"] + session_id = self._extract_session_id(key) + + if session_id: + results.append( + SessionMetadata( + session_id=session_id, + created_at=obj["LastModified"].timestamp(), + updated_at=obj["LastModified"].timestamp(), + size_bytes=obj["Size"], + storage_key=key, + ) + ) + + return results + + async def _do_get_metadata(self, key: str) -> SessionMetadata | None: + """Get metadata for a specific key. + + Args: + key: S3 key to get metadata for. + + Returns: + Metadata or None if not found. 
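+
+        Note:
+            Uses ``head_object``; a 404/``NoSuchKey`` error response is
+            treated as "not found" and ``None`` is returned.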
+ """ + client = await self._get_client() + + try: + response = await client.head_object( + Bucket=self.config.bucket, + Key=key, + ) + + session_id = self._extract_session_id(key) + if not session_id: + return None + + last_modified = response["LastModified"].timestamp() + + return SessionMetadata( + session_id=session_id, + created_at=last_modified, + updated_at=last_modified, + size_bytes=response["ContentLength"], + storage_key=key, + ) + + except ClientError as e: + if e.response["Error"]["Code"] in ("404", "NoSuchKey"): + return None + raise diff --git a/src/claude_agent_sdk/session_storage/_sync.py b/src/claude_agent_sdk/session_storage/_sync.py new file mode 100644 index 00000000..3a2d2856 --- /dev/null +++ b/src/claude_agent_sdk/session_storage/_sync.py @@ -0,0 +1,187 @@ +"""Session sync manager for coordinating cloud storage with CLI. + +This module provides the SessionSyncManager which handles: +- Downloading sessions from cloud storage on resume +- Uploading sessions to cloud storage on session end (via Stop hook) +""" + +from __future__ import annotations + +import logging +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ._protocol import SessionStorage + +logger = logging.getLogger(__name__) + + +class SessionSyncManager: + """Manages transcript synchronization between local filesystem and cloud storage. + + This manager coordinates session persistence by: + 1. Downloading existing sessions from cloud storage when resuming + 2. Uploading session transcripts to cloud storage when sessions end + + The sync manager integrates with ClaudeSDKClient via hooks - it creates + a Stop hook callback that automatically uploads transcripts when sessions end. + + WARNING: Cloud storage operations add latency. For production scale, + consider wrapping your SessionStorage with a caching layer. + + Attributes: + storage: The session storage backend (S3, GCS, or custom). + transcript_dir: Local directory for transcript files. + + Example: + Basic usage (typically handled by ClaudeSDKClient automatically): + + >>> from claude_agent_sdk.session_storage import SessionSyncManager, S3SessionStorage, S3Config + >>> + >>> storage = S3SessionStorage(S3Config(bucket="my-sessions")) + >>> manager = SessionSyncManager(storage) + >>> + >>> # Prepare for session resume (downloads from cloud if exists) + >>> local_path = await manager.prepare_session("session-123") + >>> + >>> # Create Stop hook for automatic upload + >>> hook = manager.create_stop_hook() + + Note: + In most cases, you don't need to use SessionSyncManager directly. + Simply pass `session_storage` to ClaudeAgentOptions and the SDK + handles sync automatically. + """ + + def __init__( + self, + storage: SessionStorage, + transcript_dir: Path | str | None = None, + ) -> None: + """Initialize session sync manager. + + Args: + storage: The session storage backend. + transcript_dir: Directory for local transcripts. If None, uses + a subdirectory in the system temp directory. + """ + self.storage = storage + if transcript_dir is None: + self.transcript_dir = Path(tempfile.gettempdir()) / "claude-sessions" + else: + self.transcript_dir = Path(transcript_dir) + self.transcript_dir.mkdir(parents=True, exist_ok=True) + self._active_sessions: dict[str, Path] = {} + + def get_local_transcript_path(self, session_id: str) -> Path: + """Get or create local path for a session's transcript. + + Args: + session_id: The session identifier. 
+ + Returns: + Path where the transcript will be stored locally. + """ + if session_id not in self._active_sessions: + path = self.transcript_dir / session_id / "transcript.jsonl" + path.parent.mkdir(parents=True, exist_ok=True) + self._active_sessions[session_id] = path + return self._active_sessions[session_id] + + async def prepare_session(self, session_id: str) -> Path: + """Prepare local environment for a session. + + If the session exists in cloud storage, downloads it to local path. + Called before session starts when resuming. + + Args: + session_id: The session identifier. + + Returns: + Local path where transcript will be/is stored. + """ + local_path = self.get_local_transcript_path(session_id) + + # Try to download existing transcript from cloud + if await self.storage.exists(session_id): + logger.info(f"Downloading session {session_id} from cloud storage") + downloaded = await self.storage.download_transcript(session_id, local_path) + if downloaded: + logger.info(f"Session {session_id} downloaded to {local_path}") + else: + logger.warning(f"Session {session_id} exists but download failed") + else: + logger.debug( + f"Session {session_id} not found in cloud storage (new session)" + ) + + return local_path + + async def finalize_session( + self, session_id: str, transcript_path: str | Path + ) -> None: + """Upload session transcript to cloud after session ends. + + Called from Stop hook. + + Args: + session_id: The session identifier. + transcript_path: Path to the transcript file (from hook input). + """ + path = Path(transcript_path) + + if not path.exists(): + logger.warning(f"Transcript not found at {path}, skipping upload") + return + + logger.info(f"Uploading session {session_id} to cloud storage") + try: + key = await self.storage.upload_transcript(session_id, path) + logger.info(f"Session {session_id} uploaded to {key}") + except Exception as e: + # Log but don't fail - we don't want to break session end + logger.error(f"Failed to upload session {session_id}: {e}") + + # Clean up tracking + self._active_sessions.pop(session_id, None) + + def create_stop_hook(self) -> Any: + """Create a Stop hook callback for automatic transcript upload. + + The returned callback uploads the session transcript to cloud storage + when the session ends. Errors are logged but don't fail the session. + + Returns: + HookCallback function for use with HookMatcher. 
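+
+        Note:
+            The callback reads ``session_id`` and ``transcript_path`` from the
+            hook input, uploads the transcript via ``finalize_session()``, and
+            always returns an empty dict so session behavior is unchanged.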
+ + Example: + >>> manager = SessionSyncManager(storage) + >>> hook = manager.create_stop_hook() + >>> + >>> # Use with HookMatcher + >>> from claude_agent_sdk import HookMatcher + >>> matcher = HookMatcher(matcher=None, hooks=[hook]) + """ + + async def stop_hook( + input_data: dict[str, Any], + tool_use_id: str | None, + context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback to upload transcript on session end.""" + session_id = input_data.get("session_id", "") + transcript_path = input_data.get("transcript_path", "") + + if session_id and transcript_path: + try: + await self.finalize_session(session_id, transcript_path) + except Exception as e: + # Log error but don't fail the session + logger.error(f"Stop hook failed for session {session_id}: {e}") + + # Return empty output - don't modify session behavior + return {} + + return stop_hook diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 9c09345f..1ec295b0 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -10,9 +10,12 @@ if TYPE_CHECKING: from mcp.server import Server as McpServer + + from .session_storage import SessionStorage else: # Runtime placeholder for forward reference resolution in Pydantic 2.12+ McpServer = Any + SessionStorage = Any # Permission modes PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] @@ -679,6 +682,15 @@ class ClaudeAgentOptions: # using `ClaudeSDKClient.rewind_files()`. enable_file_checkpointing: bool = False + # Session storage backend for cloud persistence. + # Enables horizontal scaling and ephemeral filesystem support. + # WARNING: Cloud storage adds latency (50-500ms per operation). + # For production at scale, consider wrapping with a caching layer. + session_storage: "SessionStorage | None" = None + # Local directory for transcript files when using session_storage. + # Defaults to system temp directory if not specified. + transcript_dir: str | Path | None = None + # SDK Control Protocol class SDKControlInterruptRequest(TypedDict): diff --git a/tests/test_session_storage.py b/tests/test_session_storage.py new file mode 100644 index 00000000..8d2bff03 --- /dev/null +++ b/tests/test_session_storage.py @@ -0,0 +1,1038 @@ +"""Comprehensive unit tests for session storage module.""" + +from pathlib import Path + +import pytest + +from claude_agent_sdk import SessionStorageError +from claude_agent_sdk.session_storage import SessionMetadata, SessionSyncManager +from claude_agent_sdk.session_storage._base import BaseSessionStorage +from claude_agent_sdk.session_storage._protocol import SessionStorage + +# ============================================================================ +# Mock Implementation +# ============================================================================ + + +class MockSessionStorage: + """In-memory session storage for testing. + + Implements the SessionStorage protocol without requiring cloud services. 
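+
+    Example (illustrative; ``path`` is any existing transcript file):
+
+        >>> storage = MockSessionStorage()
+        >>> key = await storage.upload_transcript("session-123", path)
+        >>> await storage.exists("session-123")
+        True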
+ """ + + def __init__(self) -> None: + """Initialize mock storage with in-memory data structures.""" + # Simulate cloud storage with dict: session_id -> (content bytes, metadata) + self._storage: dict[str, tuple[bytes, SessionMetadata]] = {} + + async def upload_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> str: + """Upload transcript to mock storage.""" + path = Path(local_path) + if not path.exists(): + raise SessionStorageError( + f"Local transcript not found: {path}", + session_id=session_id, + operation="upload", + ) + + content = path.read_bytes() + import time + + now = time.time() + + # Create metadata + metadata = SessionMetadata( + session_id=session_id, + created_at=now, + updated_at=now, + size_bytes=len(content), + storage_key=f"mock-storage/{session_id}/transcript.jsonl", + ) + + self._storage[session_id] = (content, metadata) + return metadata.storage_key + + async def download_transcript( + self, + session_id: str, + local_path: Path | str, + ) -> bool: + """Download transcript from mock storage.""" + if session_id not in self._storage: + return False + + path = Path(local_path) + path.parent.mkdir(parents=True, exist_ok=True) + + content, _ = self._storage[session_id] + path.write_bytes(content) + return True + + async def exists(self, session_id: str) -> bool: + """Check if session exists in mock storage.""" + return session_id in self._storage + + async def delete(self, session_id: str) -> bool: + """Delete session from mock storage.""" + if session_id not in self._storage: + return False + del self._storage[session_id] + return True + + async def list_sessions( + self, + prefix: str | None = None, + limit: int = 100, + ) -> list[SessionMetadata]: + """List sessions in mock storage.""" + results = [] + for session_id, (_, metadata) in self._storage.items(): + if prefix is None or session_id.startswith(prefix): + results.append(metadata) + if len(results) >= limit: + break + return results + + async def get_metadata(self, session_id: str) -> SessionMetadata | None: + """Get metadata for a session.""" + if session_id not in self._storage: + return None + _, metadata = self._storage[session_id] + return metadata + + +# ============================================================================ +# Test SessionMetadata +# ============================================================================ + + +class TestSessionMetadata: + """Test SessionMetadata dataclass.""" + + def test_create_metadata(self): + """Test creating metadata with all fields.""" + metadata = SessionMetadata( + session_id="session-123", + created_at=1234567890.0, + updated_at=1234567900.0, + size_bytes=1024, + storage_key="s3://bucket/session-123/transcript.jsonl", + ) + + assert metadata.session_id == "session-123" + assert metadata.created_at == 1234567890.0 + assert metadata.updated_at == 1234567900.0 + assert metadata.size_bytes == 1024 + assert metadata.storage_key == "s3://bucket/session-123/transcript.jsonl" + + def test_metadata_equality(self): + """Test metadata equality comparison.""" + metadata1 = SessionMetadata( + session_id="session-123", + created_at=1234567890.0, + updated_at=1234567900.0, + size_bytes=1024, + storage_key="key1", + ) + metadata2 = SessionMetadata( + session_id="session-123", + created_at=1234567890.0, + updated_at=1234567900.0, + size_bytes=1024, + storage_key="key1", + ) + + assert metadata1 == metadata2 + + def test_metadata_inequality(self): + """Test metadata inequality.""" + metadata1 = SessionMetadata( + session_id="session-123", + 
created_at=1234567890.0, + updated_at=1234567900.0, + size_bytes=1024, + storage_key="key1", + ) + metadata2 = SessionMetadata( + session_id="session-456", + created_at=1234567890.0, + updated_at=1234567900.0, + size_bytes=1024, + storage_key="key2", + ) + + assert metadata1 != metadata2 + + +# ============================================================================ +# Test Protocol Compliance +# ============================================================================ + + +class TestProtocolCompliance: + """Test that MockSessionStorage satisfies SessionStorage protocol.""" + + def test_mock_storage_is_session_storage(self): + """Test MockSessionStorage implements SessionStorage protocol.""" + storage = MockSessionStorage() + assert isinstance(storage, SessionStorage) + + def test_mock_storage_has_all_methods(self): + """Test MockSessionStorage has all required methods.""" + storage = MockSessionStorage() + + # Check all protocol methods exist + assert hasattr(storage, "upload_transcript") + assert hasattr(storage, "download_transcript") + assert hasattr(storage, "exists") + assert hasattr(storage, "delete") + assert hasattr(storage, "list_sessions") + assert hasattr(storage, "get_metadata") + + # Check they're callable + assert callable(storage.upload_transcript) + assert callable(storage.download_transcript) + assert callable(storage.exists) + assert callable(storage.delete) + assert callable(storage.list_sessions) + assert callable(storage.get_metadata) + + +# ============================================================================ +# Test MockSessionStorage Functionality +# ============================================================================ + + +class TestMockSessionStorage: + """Test MockSessionStorage implementation.""" + + @pytest.mark.asyncio + async def test_upload_and_download(self, tmp_path): + """Test uploading and downloading transcripts.""" + storage = MockSessionStorage() + + # Create a test file + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text('{"message": "test"}\n') + + # Upload + key = await storage.upload_transcript("session-123", transcript_path) + assert "session-123" in key + + # Download + download_path = tmp_path / "downloaded.jsonl" + success = await storage.download_transcript("session-123", download_path) + assert success is True + assert download_path.exists() + assert download_path.read_text() == '{"message": "test"}\n' + + @pytest.mark.asyncio + async def test_upload_nonexistent_file(self): + """Test uploading a file that doesn't exist.""" + storage = MockSessionStorage() + + with pytest.raises(SessionStorageError) as exc_info: + await storage.upload_transcript("session-123", "/nonexistent/file.jsonl") + + assert "Local transcript not found" in str(exc_info.value) + assert exc_info.value.session_id == "session-123" + assert exc_info.value.operation == "upload" + + @pytest.mark.asyncio + async def test_download_nonexistent_session(self, tmp_path): + """Test downloading a session that doesn't exist.""" + storage = MockSessionStorage() + + download_path = tmp_path / "downloaded.jsonl" + success = await storage.download_transcript("nonexistent", download_path) + assert success is False + assert not download_path.exists() + + @pytest.mark.asyncio + async def test_exists(self, tmp_path): + """Test checking if session exists.""" + storage = MockSessionStorage() + + # Should not exist initially + exists = await storage.exists("session-123") + assert exists is False + + # Upload a session + transcript_path = tmp_path / 
"transcript.jsonl" + transcript_path.write_text("test") + await storage.upload_transcript("session-123", transcript_path) + + # Should exist now + exists = await storage.exists("session-123") + assert exists is True + + @pytest.mark.asyncio + async def test_delete(self, tmp_path): + """Test deleting sessions.""" + storage = MockSessionStorage() + + # Upload a session + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text("test") + await storage.upload_transcript("session-123", transcript_path) + + # Delete it + deleted = await storage.delete("session-123") + assert deleted is True + + # Should not exist anymore + exists = await storage.exists("session-123") + assert exists is False + + # Deleting again should return False + deleted = await storage.delete("session-123") + assert deleted is False + + @pytest.mark.asyncio + async def test_list_sessions(self, tmp_path): + """Test listing sessions.""" + storage = MockSessionStorage() + + # Upload multiple sessions + for i in range(5): + path = tmp_path / f"transcript-{i}.jsonl" + path.write_text(f"test {i}") + await storage.upload_transcript(f"session-{i}", path) + + # List all sessions + sessions = await storage.list_sessions() + assert len(sessions) == 5 + assert all(isinstance(meta, SessionMetadata) for meta in sessions) + + @pytest.mark.asyncio + async def test_list_sessions_with_prefix(self, tmp_path): + """Test listing sessions with prefix filter.""" + storage = MockSessionStorage() + + # Upload sessions with different prefixes + for prefix in ["prod", "dev", "test"]: + for i in range(2): + path = tmp_path / f"{prefix}-{i}.jsonl" + path.write_text(f"{prefix} {i}") + await storage.upload_transcript(f"{prefix}-session-{i}", path) + + # List with prefix + sessions = await storage.list_sessions(prefix="prod") + assert len(sessions) == 2 + assert all("prod" in meta.session_id for meta in sessions) + + @pytest.mark.asyncio + async def test_list_sessions_with_limit(self, tmp_path): + """Test listing sessions respects limit.""" + storage = MockSessionStorage() + + # Upload 10 sessions + for i in range(10): + path = tmp_path / f"transcript-{i}.jsonl" + path.write_text(f"test {i}") + await storage.upload_transcript(f"session-{i}", path) + + # List with limit + sessions = await storage.list_sessions(limit=3) + assert len(sessions) <= 3 + + @pytest.mark.asyncio + async def test_get_metadata(self, tmp_path): + """Test getting metadata for a session.""" + storage = MockSessionStorage() + + # Upload a session + transcript_path = tmp_path / "transcript.jsonl" + content = "test content" + transcript_path.write_text(content) + await storage.upload_transcript("session-123", transcript_path) + + # Get metadata + metadata = await storage.get_metadata("session-123") + assert metadata is not None + assert metadata.session_id == "session-123" + assert metadata.size_bytes == len(content) + assert metadata.created_at > 0 + assert metadata.updated_at > 0 + assert "session-123" in metadata.storage_key + + @pytest.mark.asyncio + async def test_get_metadata_nonexistent(self): + """Test getting metadata for nonexistent session.""" + storage = MockSessionStorage() + + metadata = await storage.get_metadata("nonexistent") + assert metadata is None + + +# ============================================================================ +# Test BaseSessionStorage +# ============================================================================ + + +class ConcreteSessionStorage(BaseSessionStorage): + """Concrete implementation of BaseSessionStorage for 
testing.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.uploaded_keys: list[str] = [] + self.downloaded_keys: list[str] = [] + self._mock_storage: dict[str, bytes] = {} + self._should_fail = False + self._fail_count = 0 + + def set_failure_mode(self, should_fail: bool, fail_count: int = 999): + """Configure failure mode for testing retries.""" + self._should_fail = should_fail + self._fail_count = fail_count + + async def _do_upload(self, key: str, local_path: Path) -> None: + if self._should_fail and self._fail_count > 0: + self._fail_count -= 1 + raise Exception("Mock upload failure") + self.uploaded_keys.append(key) + self._mock_storage[key] = local_path.read_bytes() + + async def _do_download(self, key: str, local_path: Path) -> bool: + if self._should_fail and self._fail_count > 0: + self._fail_count -= 1 + raise Exception("Mock download failure") + self.downloaded_keys.append(key) + if key in self._mock_storage: + local_path.write_bytes(self._mock_storage[key]) + return True + return False + + async def _do_exists(self, key: str) -> bool: + return key in self._mock_storage + + async def _do_delete(self, key: str) -> bool: + if key in self._mock_storage: + del self._mock_storage[key] + return True + return False + + async def _do_list(self, prefix: str, limit: int) -> list[SessionMetadata]: + results = [] + for key in self._mock_storage: + if key.startswith(prefix): + session_id = self._extract_session_id(key) + if session_id: + results.append( + SessionMetadata( + session_id=session_id, + created_at=1234567890.0, + updated_at=1234567890.0, + size_bytes=len(self._mock_storage[key]), + storage_key=key, + ) + ) + if len(results) >= limit: + break + return results + + async def _do_get_metadata(self, key: str) -> SessionMetadata | None: + if key not in self._mock_storage: + return None + session_id = self._extract_session_id(key) + if not session_id: + return None + return SessionMetadata( + session_id=session_id, + created_at=1234567890.0, + updated_at=1234567890.0, + size_bytes=len(self._mock_storage[key]), + storage_key=key, + ) + + +class TestBaseSessionStorage: + """Test BaseSessionStorage abstract base class.""" + + def test_initialization(self): + """Test BaseSessionStorage initialization.""" + storage = ConcreteSessionStorage( + prefix="test-sessions", max_retries=5, retry_delay=2.0 + ) + assert storage.prefix == "test-sessions" + assert storage.max_retries == 5 + assert storage.retry_delay == 2.0 + + def test_prefix_stripping(self): + """Test that trailing slashes are removed from prefix.""" + storage = ConcreteSessionStorage(prefix="test-sessions/") + assert storage.prefix == "test-sessions" + + def test_get_key(self): + """Test key generation.""" + storage = ConcreteSessionStorage(prefix="claude-sessions") + + key = storage._get_key("session-123") + assert key == "claude-sessions/session-123/transcript.jsonl" + + def test_get_key_sanitization(self): + """Test key generation sanitizes session IDs.""" + storage = ConcreteSessionStorage(prefix="claude-sessions") + + # Test path traversal prevention + # The implementation replaces: / -> _, \ -> _, .. -> _ + key = storage._get_key("../../../etc/passwd") + # ".." becomes "_", "/" becomes "_", so "../../../etc/passwd" -> "______etc_passwd" + assert key == "claude-sessions/______etc_passwd/transcript.jsonl" + assert ".." 
not in key + + # Test various dangerous characters + key = storage._get_key("session/with\\slashes") + # "/" and "\" both become "_" + assert key == "claude-sessions/session_with_slashes/transcript.jsonl" + + def test_extract_session_id(self): + """Test session ID extraction from storage key.""" + storage = ConcreteSessionStorage(prefix="claude-sessions") + + session_id = storage._extract_session_id( + "claude-sessions/session-123/transcript.jsonl" + ) + assert session_id == "session-123" + + def test_extract_session_id_wrong_prefix(self): + """Test session ID extraction with wrong prefix.""" + storage = ConcreteSessionStorage(prefix="claude-sessions") + + session_id = storage._extract_session_id( + "wrong-prefix/session-123/transcript.jsonl" + ) + assert session_id is None + + def test_extract_session_id_invalid_format(self): + """Test session ID extraction with invalid format.""" + storage = ConcreteSessionStorage(prefix="claude-sessions") + + session_id = storage._extract_session_id("claude-sessions/") + assert session_id == "" + + @pytest.mark.asyncio + async def test_upload_success(self, tmp_path): + """Test successful upload.""" + storage = ConcreteSessionStorage() + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text("test content") + + key = await storage.upload_transcript("session-123", transcript_path) + assert key == "claude-sessions/session-123/transcript.jsonl" + assert "claude-sessions/session-123/transcript.jsonl" in storage.uploaded_keys + + @pytest.mark.asyncio + async def test_upload_nonexistent_file(self): + """Test upload raises error for nonexistent file.""" + storage = ConcreteSessionStorage() + + with pytest.raises(SessionStorageError) as exc_info: + await storage.upload_transcript("session-123", "/nonexistent/file.jsonl") + + assert "Local transcript not found" in str(exc_info.value) + assert exc_info.value.session_id == "session-123" + assert exc_info.value.operation == "upload" + + @pytest.mark.asyncio + async def test_upload_retry_logic(self, tmp_path): + """Test upload retry logic with exponential backoff.""" + storage = ConcreteSessionStorage(max_retries=3, retry_delay=0.01) + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text("test") + + # Fail twice, then succeed + storage.set_failure_mode(True, fail_count=2) + + key = await storage.upload_transcript("session-123", transcript_path) + assert key == "claude-sessions/session-123/transcript.jsonl" + + @pytest.mark.asyncio + async def test_upload_retry_exhausted(self, tmp_path): + """Test upload raises error after exhausting retries.""" + storage = ConcreteSessionStorage(max_retries=3, retry_delay=0.01) + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text("test") + + # Always fail + storage.set_failure_mode(True, fail_count=999) + + with pytest.raises(SessionStorageError) as exc_info: + await storage.upload_transcript("session-123", transcript_path) + + assert "Upload failed after 3 attempts" in str(exc_info.value) + assert exc_info.value.session_id == "session-123" + assert exc_info.value.operation == "upload" + assert exc_info.value.original_error is not None + + @pytest.mark.asyncio + async def test_download_success(self, tmp_path): + """Test successful download.""" + storage = ConcreteSessionStorage() + + # First upload something + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test content") + await storage.upload_transcript("session-123", upload_path) + + # Then download it + download_path = tmp_path / 
"download.jsonl" + success = await storage.download_transcript("session-123", download_path) + assert success is True + assert download_path.exists() + assert download_path.read_text() == "test content" + + @pytest.mark.asyncio + async def test_download_creates_parent_directory(self, tmp_path): + """Test download creates parent directories.""" + storage = ConcreteSessionStorage() + + # Upload something + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test") + await storage.upload_transcript("session-123", upload_path) + + # Download to nested path + download_path = tmp_path / "nested" / "deep" / "download.jsonl" + success = await storage.download_transcript("session-123", download_path) + assert success is True + assert download_path.exists() + + @pytest.mark.asyncio + async def test_download_not_found(self, tmp_path): + """Test download returns False when session not found.""" + storage = ConcreteSessionStorage() + + download_path = tmp_path / "download.jsonl" + success = await storage.download_transcript("nonexistent", download_path) + assert success is False + + @pytest.mark.asyncio + async def test_download_retry_logic(self, tmp_path): + """Test download retry logic.""" + storage = ConcreteSessionStorage(max_retries=3, retry_delay=0.01) + + # Upload first + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test") + await storage.upload_transcript("session-123", upload_path) + + # Fail twice, then succeed + storage.set_failure_mode(True, fail_count=2) + + download_path = tmp_path / "download.jsonl" + success = await storage.download_transcript("session-123", download_path) + assert success is True + + @pytest.mark.asyncio + async def test_download_retry_exhausted(self, tmp_path): + """Test download raises error after exhausting retries.""" + storage = ConcreteSessionStorage(max_retries=3, retry_delay=0.01) + + # Upload first + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test") + await storage.upload_transcript("session-123", upload_path) + + # Always fail + storage.set_failure_mode(True, fail_count=999) + + download_path = tmp_path / "download.jsonl" + with pytest.raises(SessionStorageError) as exc_info: + await storage.download_transcript("session-123", download_path) + + assert "Download failed after 3 attempts" in str(exc_info.value) + assert exc_info.value.session_id == "session-123" + assert exc_info.value.operation == "download" + + @pytest.mark.asyncio + async def test_exists(self, tmp_path): + """Test exists method.""" + storage = ConcreteSessionStorage() + + # Should not exist initially + exists = await storage.exists("session-123") + assert exists is False + + # Upload and check again + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test") + await storage.upload_transcript("session-123", upload_path) + + exists = await storage.exists("session-123") + assert exists is True + + @pytest.mark.asyncio + async def test_delete(self, tmp_path): + """Test delete method.""" + storage = ConcreteSessionStorage() + + # Upload first + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test") + await storage.upload_transcript("session-123", upload_path) + + # Delete + deleted = await storage.delete("session-123") + assert deleted is True + + # Should not exist anymore + exists = await storage.exists("session-123") + assert exists is False + + @pytest.mark.asyncio + async def test_delete_nonexistent(self): + """Test deleting nonexistent session returns False.""" + storage = ConcreteSessionStorage() + + 
deleted = await storage.delete("nonexistent") + assert deleted is False + + @pytest.mark.asyncio + async def test_list_sessions(self, tmp_path): + """Test listing sessions.""" + storage = ConcreteSessionStorage() + + # Upload multiple sessions + for i in range(5): + path = tmp_path / f"transcript-{i}.jsonl" + path.write_text(f"test {i}") + await storage.upload_transcript(f"session-{i}", path) + + # List all + sessions = await storage.list_sessions() + assert len(sessions) == 5 + + @pytest.mark.asyncio + async def test_list_sessions_with_prefix(self, tmp_path): + """Test listing sessions with prefix.""" + storage = ConcreteSessionStorage(prefix="test") + + # Upload sessions + for i in range(3): + path = tmp_path / f"transcript-{i}.jsonl" + path.write_text(f"test {i}") + await storage.upload_transcript(f"prod-{i}", path) + + # List with prefix + sessions = await storage.list_sessions(prefix="prod") + assert len(sessions) == 3 + + @pytest.mark.asyncio + async def test_get_metadata(self, tmp_path): + """Test getting metadata.""" + storage = ConcreteSessionStorage() + + # Upload first + upload_path = tmp_path / "upload.jsonl" + upload_path.write_text("test content") + await storage.upload_transcript("session-123", upload_path) + + # Get metadata + metadata = await storage.get_metadata("session-123") + assert metadata is not None + assert metadata.session_id == "session-123" + assert metadata.size_bytes > 0 + + @pytest.mark.asyncio + async def test_get_metadata_nonexistent(self): + """Test getting metadata for nonexistent session.""" + storage = ConcreteSessionStorage() + + metadata = await storage.get_metadata("nonexistent") + assert metadata is None + + +# ============================================================================ +# Test SessionSyncManager +# ============================================================================ + + +class TestSessionSyncManager: + """Test SessionSyncManager for cloud storage synchronization.""" + + def test_initialization_default_dir(self): + """Test manager initialization with default directory.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage) + + assert manager.storage == storage + assert manager.transcript_dir.exists() + assert "claude-sessions" in str(manager.transcript_dir) + + def test_initialization_custom_dir(self, tmp_path): + """Test manager initialization with custom directory.""" + storage = MockSessionStorage() + custom_dir = tmp_path / "custom-transcripts" + + manager = SessionSyncManager(storage, transcript_dir=custom_dir) + + assert manager.storage == storage + assert manager.transcript_dir == custom_dir + assert custom_dir.exists() + + def test_get_local_transcript_path(self, tmp_path): + """Test getting local transcript path.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + path = manager.get_local_transcript_path("session-123") + assert path == tmp_path / "session-123" / "transcript.jsonl" + assert path.parent.exists() + + def test_get_local_transcript_path_caching(self, tmp_path): + """Test that local transcript paths are cached.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + path1 = manager.get_local_transcript_path("session-123") + path2 = manager.get_local_transcript_path("session-123") + assert path1 is path2 # Same object + + @pytest.mark.asyncio + async def test_prepare_session_new(self, tmp_path): + """Test preparing a new session (not in cloud).""" + storage = MockSessionStorage() + 
manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + path = await manager.prepare_session("session-123") + assert path == tmp_path / "session-123" / "transcript.jsonl" + assert not path.exists() # New session, no file yet + + @pytest.mark.asyncio + async def test_prepare_session_existing(self, tmp_path): + """Test preparing an existing session (downloads from cloud).""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Upload a session to cloud first + cloud_transcript = tmp_path / "cloud-transcript.jsonl" + cloud_transcript.write_text('{"message": "from cloud"}\n') + await storage.upload_transcript("session-123", cloud_transcript) + + # Prepare session (should download) + local_path = await manager.prepare_session("session-123") + assert local_path.exists() + assert local_path.read_text() == '{"message": "from cloud"}\n' + + @pytest.mark.asyncio + async def test_finalize_session_success(self, tmp_path): + """Test finalizing session uploads to cloud.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Create a local transcript + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text('{"message": "test"}\n') + + # Finalize (upload) + await manager.finalize_session("session-123", transcript_path) + + # Verify uploaded to cloud + assert await storage.exists("session-123") + + @pytest.mark.asyncio + async def test_finalize_session_nonexistent_file(self, tmp_path, caplog): + """Test finalizing with nonexistent transcript logs warning.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Finalize with nonexistent file + await manager.finalize_session("session-123", tmp_path / "nonexistent.jsonl") + + # Should log warning but not raise + assert "Transcript not found" in caplog.text + + @pytest.mark.asyncio + async def test_finalize_session_upload_error(self, tmp_path, caplog): + """Test finalize handles upload errors gracefully.""" + # Create storage that will fail on upload + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Patch upload to fail + async def failing_upload(*args, **kwargs): + raise Exception("Upload failed") + + storage.upload_transcript = failing_upload + + # Create a local transcript + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text('{"message": "test"}\n') + + # Finalize should not raise, just log error + await manager.finalize_session("session-123", transcript_path) + + assert "Failed to upload session" in caplog.text + + @pytest.mark.asyncio + async def test_finalize_session_cleans_up_tracking(self, tmp_path): + """Test finalize cleans up internal session tracking.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Track a session + manager.get_local_transcript_path("session-123") + assert "session-123" in manager._active_sessions + + # Create transcript and finalize + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text("test") + await manager.finalize_session("session-123", transcript_path) + + # Should be cleaned up + assert "session-123" not in manager._active_sessions + + @pytest.mark.asyncio + async def test_create_stop_hook(self, tmp_path): + """Test creating a stop hook callback.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + hook = 
manager.create_stop_hook() + assert callable(hook) + + @pytest.mark.asyncio + async def test_stop_hook_uploads_transcript(self, tmp_path): + """Test stop hook uploads transcript.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Create transcript + transcript_path = tmp_path / "transcript.jsonl" + transcript_path.write_text('{"message": "test"}\n') + + # Create and call hook + hook = manager.create_stop_hook() + input_data = { + "session_id": "session-123", + "transcript_path": str(transcript_path), + } + result = await hook(input_data, None, {}) + + # Should return empty dict + assert result == {} + + # Should have uploaded + assert await storage.exists("session-123") + + @pytest.mark.asyncio + async def test_stop_hook_missing_session_id(self, tmp_path): + """Test stop hook handles missing session_id gracefully.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + hook = manager.create_stop_hook() + input_data = {"transcript_path": str(tmp_path / "transcript.jsonl")} + result = await hook(input_data, None, {}) + + # Should not raise, just return empty dict + assert result == {} + + @pytest.mark.asyncio + async def test_stop_hook_missing_transcript_path(self): + """Test stop hook handles missing transcript_path gracefully.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage) + + hook = manager.create_stop_hook() + input_data = {"session_id": "session-123"} + result = await hook(input_data, None, {}) + + # Should not raise, just return empty dict + assert result == {} + + @pytest.mark.asyncio + async def test_stop_hook_handles_errors(self, tmp_path, caplog): + """Test stop hook handles errors without raising.""" + storage = MockSessionStorage() + manager = SessionSyncManager(storage, transcript_dir=tmp_path) + + # Patch finalize_session to fail + async def failing_finalize(*args, **kwargs): + raise Exception("Finalize failed") + + manager.finalize_session = failing_finalize + + # Call hook + hook = manager.create_stop_hook() + input_data = { + "session_id": "session-123", + "transcript_path": str(tmp_path / "transcript.jsonl"), + } + result = await hook(input_data, None, {}) + + # Should return empty dict and log error + assert result == {} + assert "Stop hook failed" in caplog.text + + +# ============================================================================ +# Test SessionStorageError +# ============================================================================ + + +class TestSessionStorageError: + """Test SessionStorageError exception.""" + + def test_basic_error(self): + """Test basic error creation.""" + error = SessionStorageError("Upload failed") + assert "Upload failed" in str(error) + assert error.session_id is None + assert error.operation is None + assert error.original_error is None + + def test_error_with_session_id(self): + """Test error with session ID.""" + error = SessionStorageError("Upload failed", session_id="session-123") + assert "Upload failed" in str(error) + assert "session-123" in str(error) + assert error.session_id == "session-123" + + def test_error_with_operation(self): + """Test error with operation.""" + error = SessionStorageError("Failed", operation="upload") + assert "upload" in str(error) + assert error.operation == "upload" + + def test_error_with_original_error(self): + """Test error with original error.""" + original = Exception("Connection timeout") + error = SessionStorageError("Upload failed", 
original_error=original) + assert "Connection timeout" in str(error) + assert error.original_error is original + + def test_error_with_all_fields(self): + """Test error with all fields.""" + original = Exception("Network error") + error = SessionStorageError( + "Upload failed", + session_id="session-123", + operation="upload", + original_error=original, + ) + assert "Upload failed" in str(error) + assert "session-123" in str(error) + assert "upload" in str(error) + assert "Network error" in str(error) + assert error.session_id == "session-123" + assert error.operation == "upload" + assert error.original_error is original + + def test_error_is_exception(self): + """Test SessionStorageError is an Exception.""" + error = SessionStorageError("Test") + assert isinstance(error, Exception) + + def test_error_can_be_raised(self): + """Test SessionStorageError can be raised and caught.""" + with pytest.raises(SessionStorageError) as exc_info: + raise SessionStorageError("Test error", session_id="test-123") + + assert "Test error" in str(exc_info.value) + assert exc_info.value.session_id == "test-123"
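+
+
+# ============================================================================
+# Illustrative cross-server round trip
+# ============================================================================
+# Minimal sketch of the flow the sync manager enables across servers. It only
+# composes behaviours already asserted above (finalize_session uploads,
+# prepare_session downloads) against the shared MockSessionStorage; the
+# "server-a"/"server-b" directories and the session id are local to this
+# sketch, not SDK conventions.
+
+
+@pytest.mark.asyncio
+async def test_round_trip_between_sync_managers(tmp_path):
+    """Transcript finalized by one manager is retrievable through another."""
+    storage = MockSessionStorage()
+    server_a = SessionSyncManager(storage, transcript_dir=tmp_path / "server-a")
+    server_b = SessionSyncManager(storage, transcript_dir=tmp_path / "server-b")
+
+    # "Server A" writes a local transcript and finalizes it, which uploads
+    # the file to the shared storage backend.
+    transcript = tmp_path / "server-a-transcript.jsonl"
+    transcript.write_text('{"message": "from server a"}\n')
+    await server_a.finalize_session("session-shared", transcript)
+
+    # "Server B" prepares the same session; because it exists in storage,
+    # prepare_session downloads it to server B's local transcript path.
+    local_path = await server_b.prepare_session("session-shared")
+    assert local_path.exists()
+    assert local_path.read_text() == '{"message": "from server a"}\n'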