Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ hatch test --all # Test all Python versions (3.10-3.13)
- Use `moto` for mocking AWS services
- Use `pytest.mark.asyncio` for async tests
- Keep tests focused and independent
- Import packages at the top of the test files

## MCP Tasks (Experimental)

Expand Down
8 changes: 6 additions & 2 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
"""
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
if "cachePoint" in content:
return {"cachePoint": {"type": content["cachePoint"]["type"]}}
cache_point = content["cachePoint"]
result: dict[str, Any] = {"type": cache_point["type"]}
if "ttl" in cache_point:
result["ttl"] = cache_point["ttl"]
return {"cachePoint": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
if "document" in content:
document = content["document"]
result: dict[str, Any] = {}
result = {}

# Handle required fields (all optional due to total=False)
if "name" in document:
Expand Down
4 changes: 3 additions & 1 deletion src/strands/types/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Literal

from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict

from .citations import CitationsContentBlock
from .media import DocumentContent, ImageContent, VideoContent
Expand Down Expand Up @@ -66,9 +66,11 @@ class CachePoint(TypedDict):
Attributes:
type: The type of cache point, typically "default".
ttl: Optional TTL duration for cache entries. Valid values are "5m" (5 minutes) or "1h" (1 hour).
"""

type: str
ttl: NotRequired[Literal["5m", "1h"]]


class ContentBlock(TypedDict, total=False):
Expand Down
47 changes: 47 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,53 @@ def test_format_request_filters_cache_point_content_blocks(model, model_id):
assert "extraField" not in cache_point_block


def test_format_request_preserves_cache_point_ttl(model, model_id):
    """Test that format_request preserves the ttl field in cachePoint content blocks."""
    cache_point = {"type": "default", "ttl": "1h"}
    messages = [{"role": "user", "content": [{"cachePoint": dict(cache_point)}]}]

    request = model._format_request(messages)

    # The formatted request must carry the ttl through unchanged.
    formatted_block = request["messages"][0]["content"][0]["cachePoint"]
    assert formatted_block == {"type": "default", "ttl": "1h"}
    assert formatted_block["ttl"] == "1h"


def test_format_request_cache_point_without_ttl(model, model_id):
    """Test that cache points work without ttl field (backward compatibility)."""
    messages = [{"role": "user", "content": [{"cachePoint": {"type": "default"}}]}]

    request = model._format_request(messages)

    # No ttl in, no ttl out: only the type key should survive formatting.
    formatted_block = request["messages"][0]["content"][0]["cachePoint"]
    assert formatted_block == {"type": "default"}
    assert "ttl" not in formatted_block


def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings):
"""Test that unknown config keys emit a warning."""
BedrockModel(model_id="test-model", invalid_param="test")
Expand Down
151 changes: 151 additions & 0 deletions tests_integ/models/test_model_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import time
import uuid

import pydantic
import pytest

Expand Down Expand Up @@ -323,3 +326,151 @@ def test_multi_prompt_system_content():
agent = Agent(system_prompt=system_prompt_content, load_tools_from_directory=False)
# just verifying there is no failure
agent("Hello!")


def test_prompt_caching_with_5m_ttl(streaming_model):
    """Test prompt caching with 5 minute TTL and verify cache metrics.

    This test verifies:
    1. First call creates cache (cacheWriteInputTokens > 0)
    2. Second call reads from cache (cacheReadInputTokens > 0)
    """
    # A unique marker keeps this run's cache entry separate from earlier runs.
    run_marker = str(uuid.uuid4())
    # Repeat filler text so the cached prefix clears the 1024-token minimum for caching.
    context_text = f"Background information for test {run_marker}: " + "This is important context. " * 200

    cached_system_prompt = [
        {"text": context_text},
        {"cachePoint": {"type": "default", "ttl": "5m"}},
        {"text": "You are a helpful assistant."},
    ]

    agent = Agent(
        model=streaming_model,
        system_prompt=cached_system_prompt,
        load_tools_from_directory=False,
    )

    # First call should create the cache (cache write).
    first = agent("What is 2+2?")
    assert len(str(first)) > 0
    assert first.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
        "Expected cacheWriteInputTokens > 0 on first call"
    )

    # Second call should be served from the cache (cache read).
    second = agent("What is 3+3?")
    assert len(str(second)) > 0
    assert second.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, (
        "Expected cacheReadInputTokens > 0 on second call"
    )


def test_prompt_caching_with_1h_ttl(non_streaming_model):
    """Test prompt caching with 1 hour TTL and verify cache metrics.

    Uses unique content per test run to avoid cache conflicts with concurrent CI runs.
    Even with 1hr TTL, unique content ensures cache entries don't interfere across tests.
    """
    # uuid4 is collision-proof even for runs that start in the same microsecond
    # (a timestamp-based id is not), and matches the other caching tests here.
    unique_id = str(uuid.uuid4())
    # Minimum 1024 tokens required for caching
    large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 200)

    system_prompt_with_cache = [
        {"text": large_context},
        {"cachePoint": {"type": "default", "ttl": "1h"}},
        {"text": "You are a helpful assistant."},
    ]

    agent = Agent(
        model=non_streaming_model,
        system_prompt=system_prompt_with_cache,
        load_tools_from_directory=False,
    )

    # First call should create the cache
    result1 = agent("What is 2+2?")
    assert len(str(result1)) > 0

    # Verify cache write occurred
    assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
        "Expected cacheWriteInputTokens > 0 on first call with 1h TTL"
    )

    # Second call should use the cached content
    result2 = agent("What is 3+3?")
    assert len(str(result2)) > 0

    # Verify cache read occurred
    assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, (
        "Expected cacheReadInputTokens > 0 on second call with 1h TTL"
    )


def test_prompt_caching_with_ttl_in_messages(streaming_model):
    """Test prompt caching with TTL in message content and verify cache metrics."""
    agent = Agent(model=streaming_model, load_tools_from_directory=False)

    run_marker = str(uuid.uuid4())
    # Repeat filler text so the cached block clears the 1024-token minimum for caching.
    context_text = f"Important context for test {run_marker}: " + "This is critical information. " * 200

    cached_content = [
        {"text": context_text},
        {"cachePoint": {"type": "default", "ttl": "5m"}},
        {"text": "Based on the context above, what is 5+5?"},
    ]

    # First call creates the cache entry for the message content.
    first = agent(cached_content)
    assert len(str(first)) > 0
    assert first.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
        "Expected cacheWriteInputTokens > 0 when caching message content"
    )

    # A follow-up call should hit that cache entry.
    second = agent("What about 10+10?")
    assert len(str(second)) > 0
    assert second.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, (
        "Expected cacheReadInputTokens > 0 on subsequent call"
    )


def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model):
    """Test that prompt caching works without TTL (backward compatibility).

    Verifies that cache points work correctly when TTL is not specified,
    maintaining backward compatibility with existing code.
    """
    run_marker = str(uuid.uuid4())
    # Repeat filler text so the cached prefix clears the 1024-token minimum for caching.
    context_text = f"Background information for test {run_marker}: " + "This is important context. " * 200

    cached_system_prompt = [
        {"text": context_text},
        {"cachePoint": {"type": "default"}},  # No TTL specified
        {"text": "You are a helpful assistant."},
    ]

    agent = Agent(
        model=non_streaming_model,
        system_prompt=cached_system_prompt,
        load_tools_from_directory=False,
    )

    result = agent("Hello!")
    assert len(str(result)) > 0

    # Cache writes must still happen when the cache point omits ttl.
    assert result.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
        "Expected cacheWriteInputTokens > 0 even without TTL specified"
    )