diff --git a/AGENTS.md b/AGENTS.md
index 9199d50fa..78995a6ff 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -413,6 +413,7 @@ hatch test --all # Test all Python versions (3.10-3.13)
 - Use `moto` for mocking AWS services
 - Use `pytest.mark.asyncio` for async tests
 - Keep tests focused and independent
+- Import packages at the top of the test files
 
 ## MCP Tasks (Experimental)
 
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 596936e6f..bb30f8942 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -492,12 +492,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         """
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
         if "cachePoint" in content:
-            return {"cachePoint": {"type": content["cachePoint"]["type"]}}
+            cache_point = content["cachePoint"]
+            result: dict[str, Any] = {"type": cache_point["type"]}
+            if "ttl" in cache_point:
+                result["ttl"] = cache_point["ttl"]
+            return {"cachePoint": result}
 
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
         if "document" in content:
             document = content["document"]
-            result: dict[str, Any] = {}
+            result = {}
 
             # Handle required fields (all optional due to total=False)
             if "name" in document:
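The `cachePoint` branch above now builds the block incrementally so the optional `ttl` can be forwarded. A minimal standalone sketch of that pass-through logic, for illustration only (`format_cache_point` is our name for the extracted branch, not a function in the patch):

```python
from typing import Any


def format_cache_point(cache_point: dict[str, Any]) -> dict[str, Any]:
    """Copy "type" unconditionally; forward "ttl" only when the caller set one."""
    result: dict[str, Any] = {"type": cache_point["type"]}
    if "ttl" in cache_point:
        result["ttl"] = cache_point["ttl"]
    return {"cachePoint": result}


# Without a ttl, the output is identical to the pre-patch behavior.
assert format_cache_point({"type": "default"}) == {"cachePoint": {"type": "default"}}
# With a ttl, the field is passed through to the Converse payload unchanged.
assert format_cache_point({"type": "default", "ttl": "1h"}) == {"cachePoint": {"type": "default", "ttl": "1h"}}
```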
""" type: str + ttl: NotRequired[Literal["5m", "1h"]] class ContentBlock(TypedDict, total=False): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 1410e129b..a013f8e68 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2050,6 +2050,53 @@ def test_format_request_filters_cache_point_content_blocks(model, model_id): assert "extraField" not in cache_point_block +def test_format_request_preserves_cache_point_ttl(model, model_id): + """Test that format_request preserves the ttl field in cachePoint content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + "ttl": "1h", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default", "ttl": "1h"} + assert cache_point_block == expected + assert cache_point_block["ttl"] == "1h" + + +def test_format_request_cache_point_without_ttl(model, model_id): + """Test that cache points work without ttl field (backward compatibility).""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default"} + assert cache_point_block == expected + assert "ttl" not in cache_point_block + + def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): """Test that unknown config keys emit a warning.""" BedrockModel(model_id="test-model", invalid_param="test") diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 0b3aa7b47..fa9ef65ba 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -1,3 +1,6 @@ +import time +import uuid + import pydantic import pytest @@ -323,3 +326,151 @@ def test_multi_prompt_system_content(): agent = Agent(system_prompt=system_prompt_content, load_tools_from_directory=False) # just verifying there is no failure agent("Hello!") + + +def test_prompt_caching_with_5m_ttl(streaming_model): + """Test prompt caching with 5 minute TTL and verify cache metrics. + + This test verifies: + 1. First call creates cache (cacheWriteInputTokens > 0) + 2. Second call reads from cache (cacheReadInputTokens > 0) + """ + # Use unique identifier to avoid cache conflicts between test runs + unique_id = str(uuid.uuid4()) + # Minimum 1024 tokens required for caching + large_context = f"Background information for test {unique_id}: " + ("This is important context. 
" * 200) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=streaming_model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache (cache write) + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred on first call + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call" + ) + + # Second call should use the cached content (cache read) + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred on second call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call" + ) + + +def test_prompt_caching_with_1h_ttl(non_streaming_model): + """Test prompt caching with 1 hour TTL and verify cache metrics. + + Uses unique content per test run to avoid cache conflicts with concurrent CI runs. + Even with 1hr TTL, unique content ensures cache entries don't interfere across tests. + """ + # Use timestamp to ensure unique content per test run (avoids CI conflicts) + unique_id = str(int(time.time() * 1000000)) # microsecond timestamp + # Minimum 1024 tokens required for caching + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 200) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=non_streaming_model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call with 1h TTL" + ) + + # Second call should use the cached content + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call with 1h TTL" + ) + + +def test_prompt_caching_with_ttl_in_messages(streaming_model): + """Test prompt caching with TTL in message content and verify cache metrics.""" + agent = Agent(model=streaming_model, load_tools_from_directory=False) + + unique_id = str(uuid.uuid4()) + # Large content block to cache (minimum 1024 tokens) + large_text = f"Important context for test {unique_id}: " + ("This is critical information. 
" * 200) + + content_with_cache = [ + {"text": large_text}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "Based on the context above, what is 5+5?"}, + ] + + # First call creates cache + result1 = agent(content_with_cache) + assert len(str(result1)) > 0 + + # Verify cache write in message content + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 when caching message content" + ) + + # Subsequent call should use cache + result2 = agent("What about 10+10?") + assert len(str(result2)) > 0 + + # Verify cache read on subsequent call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on subsequent call" + ) + + +def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model): + """Test that prompt caching works without TTL (backward compatibility). + + Verifies that cache points work correctly when TTL is not specified, + maintaining backward compatibility with existing code. + """ + unique_id = str(uuid.uuid4()) + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 200) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default"}}, # No TTL specified + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=non_streaming_model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + result = agent("Hello!") + assert len(str(result)) > 0 + + # Verify cache write occurred even without TTL + assert result.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 even without TTL specified" + )