Skip to content

Commit 2e0aece

Browse files
committed
fix: convert ContentBlock to proper Union type (#148)
Refactored ContentBlock from a single TypedDict with optional fields to a Union of 9 specific TypedDict classes to enforce the Bedrock API constraint that each ContentBlock must contain exactly one content type. Changes: - Created specific TypedDict classes: ContentBlockText, ContentBlockImage, ContentBlockDocument, ContentBlockVideo, ContentBlockToolUse, ContentBlockToolResult, ContentBlockGuardContent, ContentBlockReasoningContent, ContentBlockCitations - Added TypeGuard functions to enable proper type narrowing - Added CONTENT_BLOCK_KEYS constant for runtime validation - Updated all model providers to use type guard functions - Updated tests to use ContentBlockText for instantiation - Fixed agent.py to use CONTENT_BLOCK_KEYS instead of __annotations__ This ensures type safety and prevents invalid ContentBlock structures while maintaining full mypy compliance with zero type errors.
1 parent 033574b commit 2e0aece

29 files changed

+537
-178
lines changed

src/strands/agent/agent.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@
5858
from ..tools.watcher import ToolWatcher
5959
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
6060
from ..types.agent import AgentInput
61-
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
61+
from ..types.content import (
62+
CONTENT_BLOCK_KEYS,
63+
ContentBlock,
64+
ContentBlockText,
65+
Message,
66+
Messages,
67+
SystemContentBlock,
68+
is_tool_result_block,
69+
is_tool_use_block,
70+
)
6271
from ..types.exceptions import ContextWindowOverflowException
6372
from ..types.traces import AttributeValue
6473
from .agent_result import AgentResult
@@ -717,7 +726,9 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
717726
"Agents latest message is toolUse, appending a toolResult message to have valid conversation."
718727
)
719728
tool_use_ids = [
720-
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
729+
content["toolUse"]["toolUseId"]
730+
for content in self.messages[-1]["content"]
731+
if is_tool_use_block(content)
721732
]
722733
await self._append_messages(
723734
{
@@ -740,7 +751,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
740751
messages = cast(Messages, prompt)
741752

742753
# Check if all items are content blocks
743-
elif all(any(key in ContentBlock.__annotations__.keys() for key in item) for item in prompt):
754+
elif all(any(key in CONTENT_BLOCK_KEYS for key in item) for item in prompt):
744755
# Treat as List[ContentBlock] input - convert to user message
745756
# This allows invalid structures to be passed through to the model
746757
messages = [{"role": "user", "content": cast(list[ContentBlock], prompt)}]
@@ -835,14 +846,14 @@ def _redact_user_content(self, content: list[ContentBlock], redact_message: str)
835846
- otherwise, the entire content of the message is replaced
836847
with a single text block with the redact message.
837848
"""
838-
redacted_content = []
849+
redacted_content: list[ContentBlock] = []
839850
for block in content:
840-
if "toolResult" in block:
851+
if is_tool_result_block(block):
841852
block["toolResult"]["content"] = [{"text": redact_message}]
842853
redacted_content.append(block)
843854

844855
if not redacted_content:
845856
# Text content is added only if no toolResult blocks were found
846-
redacted_content = [{"text": redact_message}]
857+
redacted_content = [ContentBlockText(text=redact_message)]
847858

848859
return redacted_content

src/strands/agent/agent_result.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..interrupt import Interrupt
1212
from ..telemetry.metrics import EventLoopMetrics
13-
from ..types.content import Message
13+
from ..types.content import Message, is_text_block
1414
from ..types.streaming import StopReason
1515

1616

@@ -48,8 +48,8 @@ def __str__(self) -> str:
4848

4949
result = ""
5050
for item in content_array:
51-
if isinstance(item, dict) and "text" in item:
52-
result += item.get("text", "") + "\n"
51+
if is_text_block(item):
52+
result += item["text"] + "\n"
5353

5454
if not result and self.structured_output:
5555
result = self.structured_output.model_dump_json()

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Sliding window conversation history management."""
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, Optional
4+
from typing import TYPE_CHECKING, Any, Optional, cast
55

66
if TYPE_CHECKING:
77
from ...agent.agent import Agent
88

9-
from ...types.content import Messages
9+
from ...types.content import ContentBlockToolResult, Messages, is_tool_result_block
1010
from ...types.exceptions import ContextWindowOverflowException
1111
from .conversation_manager import ConversationManager
1212

@@ -132,21 +132,23 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
132132
changes_made = False
133133
tool_result_too_large_message = "The tool result was too large!"
134134
for i, content in enumerate(message.get("content", [])):
135-
if isinstance(content, dict) and "toolResult" in content:
135+
if is_tool_result_block(content):
136136
tool_result_content_text = next(
137137
(item["text"] for item in content["toolResult"]["content"] if "text" in item),
138138
"",
139139
)
140+
# Cast to ensure type narrowing for indexed access
141+
content_block = cast(ContentBlockToolResult, message["content"][i])
140142
# make the overwriting logic togglable
141143
if (
142-
message["content"][i]["toolResult"]["status"] == "error"
144+
content_block["toolResult"]["status"] == "error"
143145
and tool_result_content_text == tool_result_too_large_message
144146
):
145147
logger.info("ToolResult has already been updated, skipping overwrite")
146148
return False
147149
# Update status to error with informative message
148-
message["content"][i]["toolResult"]["status"] = "error"
149-
message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
150+
content_block["toolResult"]["status"] = "error"
151+
content_block["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
150152
changes_made = True
151153

152154
return changes_made

src/strands/event_loop/_recover_message_on_max_tokens_reached.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
import logging
99

10-
from ..types.content import ContentBlock, Message
11-
from ..types.tools import ToolUse
10+
from ..types.content import ContentBlock, Message, is_tool_use_block
1211

1312
logger = logging.getLogger(__name__)
1413

@@ -52,11 +51,12 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message:
5251

5352
valid_content: list[ContentBlock] = []
5453
for content in message["content"] or []:
55-
tool_use: ToolUse | None = content.get("toolUse")
56-
if not tool_use:
54+
if not is_tool_use_block(content):
5755
valid_content.append(content)
5856
continue
5957

58+
tool_use = content["toolUse"]
59+
6060
# Replace all tool uses with error messages when max_tokens is reached
6161
display_name = tool_use.get("name") or "<unknown>"
6262
logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)

src/strands/event_loop/streaming.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,15 @@
2222
TypedEvent,
2323
)
2424
from ..types.citations import CitationsContentBlock
25-
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
25+
from ..types.content import (
26+
ContentBlock,
27+
ContentBlockReasoningContent,
28+
Message,
29+
Messages,
30+
SystemContentBlock,
31+
is_text_block,
32+
is_tool_use_block,
33+
)
2634
from ..types.streaming import (
2735
ContentBlockDeltaEvent,
2836
ContentBlockStart,
@@ -69,7 +77,7 @@ def _normalize_messages(messages: Messages) -> Messages:
6977
# Ensure the tool-uses always have valid names before sending
7078
# https://github.com/strands-agents/sdk-python/issues/1069
7179
for item in content:
72-
if "toolUse" in item:
80+
if is_tool_use_block(item):
7381
has_tool_use = True
7482
tool_use: ToolUse = item["toolUse"]
7583

@@ -82,13 +90,13 @@ def _normalize_messages(messages: Messages) -> Messages:
8290
if has_tool_use:
8391
# Remove blank 'text' items for assistant messages
8492
before_len = len(content)
85-
content[:] = [item for item in content if "text" not in item or item["text"].strip()]
93+
content[:] = [item for item in content if not is_text_block(item) or item["text"].strip()]
8694
if not removed_blank_message_content_text and before_len != len(content):
8795
removed_blank_message_content_text = True
8896
else:
8997
# Replace blank 'text' with '[blank text]' for assistant messages
9098
for item in content:
91-
if "text" in item and not item["text"].strip():
99+
if is_text_block(item) and not item["text"].strip():
92100
replaced_blank_message_content_text = True
93101
item["text"] = "[blank text]"
94102

@@ -136,13 +144,13 @@ def remove_blank_messages_content_text(messages: Messages) -> Messages:
136144
if has_tool_use:
137145
# Remove blank 'text' items for assistant messages
138146
before_len = len(content)
139-
content[:] = [item for item in content if "text" not in item or item["text"].strip()]
147+
content[:] = [item for item in content if not is_text_block(item) or item["text"].strip()]
140148
if not removed_blank_message_content_text and before_len != len(content):
141149
removed_blank_message_content_text = True
142150
else:
143151
# Replace blank 'text' with '[blank text]' for assistant messages
144152
for item in content:
145-
if "text" in item and not item["text"].strip():
153+
if is_text_block(item) and not item["text"].strip():
146154
replaced_blank_message_content_text = True
147155
item["text"] = "[blank text]"
148156

@@ -298,7 +306,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
298306
state["text"] = ""
299307

300308
elif reasoning_text:
301-
content_block: ContentBlock = {
309+
content_block: ContentBlockReasoningContent = {
302310
"reasoningContent": {
303311
"reasoningText": {
304312
"text": state["reasoningText"],

src/strands/models/anthropic.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515

1616
from ..event_loop.streaming import process_stream
1717
from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec
18-
from ..types.content import ContentBlock, Messages
18+
from ..types.content import (
19+
ContentBlock,
20+
Messages,
21+
is_document_block,
22+
is_image_block,
23+
is_reasoning_content_block,
24+
is_text_block,
25+
is_tool_result_block,
26+
is_tool_use_block,
27+
)
1928
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
2029
from ..types.streaming import StreamEvent
2130
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
@@ -108,7 +117,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
108117
Raises:
109118
TypeError: If the content block type cannot be converted to an Anthropic-compatible format.
110119
"""
111-
if "document" in content:
120+
if is_document_block(content):
112121
mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
113122
return {
114123
"source": {
@@ -124,7 +133,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
124133
"type": "document",
125134
}
126135

127-
if "image" in content:
136+
if is_image_block(content):
128137
return {
129138
"source": {
130139
"data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
@@ -134,25 +143,25 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
134143
"type": "image",
135144
}
136145

137-
if "reasoningContent" in content:
146+
if is_reasoning_content_block(content):
138147
return {
139148
"signature": content["reasoningContent"]["reasoningText"]["signature"],
140149
"thinking": content["reasoningContent"]["reasoningText"]["text"],
141150
"type": "thinking",
142151
}
143152

144-
if "text" in content:
153+
if is_text_block(content):
145154
return {"text": content["text"], "type": "text"}
146155

147-
if "toolUse" in content:
156+
if is_tool_use_block(content):
148157
return {
149158
"id": content["toolUse"]["toolUseId"],
150159
"input": content["toolUse"]["input"],
151160
"name": content["toolUse"]["name"],
152161
"type": "tool_use",
153162
}
154163

155-
if "toolResult" in content:
164+
if is_tool_result_block(content):
156165
return {
157166
"content": [
158167
self._format_request_message_content(

0 commit comments

Comments
 (0)