diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 256c74415..52a552383 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -58,7 +58,16 @@ from ..tools.watcher import ToolWatcher from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput -from ..types.content import ContentBlock, Message, Messages, SystemContentBlock +from ..types.content import ( + CONTENT_BLOCK_KEYS, + ContentBlock, + ContentBlockText, + Message, + Messages, + SystemContentBlock, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ContextWindowOverflowException from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -720,7 +729,9 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: "Agents latest message is toolUse, appending a toolResult message to have valid conversation." ) tool_use_ids = [ - content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content + content["toolUse"]["toolUseId"] + for content in self.messages[-1]["content"] + if is_tool_use_block(content) ] await self._append_messages( { @@ -743,7 +754,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: messages = cast(Messages, prompt) # Check if all items are content blocks - elif all(any(key in ContentBlock.__annotations__.keys() for key in item) for item in prompt): + elif all(any(key in CONTENT_BLOCK_KEYS for key in item) for item in prompt): # Treat as List[ContentBlock] input - convert to user message # This allows invalid structures to be passed through to the model messages = [{"role": "user", "content": cast(list[ContentBlock], prompt)}] @@ -838,14 +849,14 @@ def _redact_user_content(self, content: list[ContentBlock], redact_message: str) - otherwise, the entire content of the message is replaced with a single text block with the redact message. """ - redacted_content = [] + redacted_content: list[ContentBlock] = [] for block in content: - if "toolResult" in block: + if is_tool_result_block(block): block["toolResult"]["content"] = [{"text": redact_message}] redacted_content.append(block) if not redacted_content: # Text content is added only if no toolResult blocks were found - redacted_content = [{"text": redact_message}] + redacted_content = [ContentBlockText(text=redact_message)] return redacted_content diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index ef8a11029..b77303ea2 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -10,7 +10,7 @@ from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message +from ..types.content import Message, is_text_block from ..types.streaming import StopReason @@ -48,8 +48,8 @@ def __str__(self) -> str: result = "" for item in content_array: - if isinstance(item, dict) and "text" in item: - result += item.get("text", "") + "\n" + if is_text_block(item): + result += item["text"] + "\n" if not result and self.structured_output: result = self.structured_output.model_dump_json() diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index a063e55eb..a5ab795bd 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,13 +1,13 @@ """Sliding window conversation history management.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast if TYPE_CHECKING: from ...agent.agent import Agent from ...hooks import BeforeModelCallEvent, HookRegistry -from ...types.content import Messages +from ...types.content import ContentBlockToolResult, Messages, is_tool_result_block from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -216,21 +216,23 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: changes_made = False tool_result_too_large_message = "The tool result was too large!" for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: + if is_tool_result_block(content): tool_result_content_text = next( (item["text"] for item in content["toolResult"]["content"] if "text" in item), "", ) + # Cast to ensure type narrowing for indexed access + content_block = cast(ContentBlockToolResult, message["content"][i]) # make the overwriting logic togglable if ( - message["content"][i]["toolResult"]["status"] == "error" + content_block["toolResult"]["status"] == "error" and tool_result_content_text == tool_result_too_large_message ): logger.info("ToolResult has already been updated, skipping overwrite") return False # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + content_block["toolResult"]["status"] = "error" + content_block["toolResult"]["content"] = [{"text": tool_result_too_large_message}] changes_made = True return changes_made diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index ab6fb4abe..cf3d2a9f5 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -7,8 +7,7 @@ import logging -from ..types.content import ContentBlock, Message -from ..types.tools import ToolUse +from ..types.content import ContentBlock, Message, is_tool_use_block logger = logging.getLogger(__name__) @@ -52,11 +51,12 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: valid_content: list[ContentBlock] = [] for content in message["content"] or []: - tool_use: ToolUse | None = content.get("toolUse") - if not tool_use: + if not is_tool_use_block(content): valid_content.append(content) continue + tool_use = content["toolUse"] + # Replace all tool uses with error messages when max_tokens is reached display_name = tool_use.get("name") or "" logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..b2a50d71b 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -22,7 +22,15 @@ TypedEvent, ) from ..types.citations import CitationsContentBlock -from ..types.content import ContentBlock, Message, Messages, SystemContentBlock +from ..types.content import ( + ContentBlock, + ContentBlockReasoningContent, + Message, + Messages, + SystemContentBlock, + is_text_block, + is_tool_use_block, +) from ..types.streaming import ( ContentBlockDeltaEvent, ContentBlockStart, @@ -69,7 +77,7 @@ def _normalize_messages(messages: Messages) -> Messages: # Ensure the tool-uses always have valid names before sending # https://github.com/strands-agents/sdk-python/issues/1069 for item in content: - if "toolUse" in item: + if is_tool_use_block(item): has_tool_use = True tool_use: ToolUse = item["toolUse"] @@ -82,13 +90,13 @@ def _normalize_messages(messages: Messages) -> Messages: if has_tool_use: # Remove blank 'text' items for assistant messages before_len = len(content) - content[:] = [item for item in content if "text" not in item or item["text"].strip()] + content[:] = [item for item in content if not is_text_block(item) or item["text"].strip()] if not removed_blank_message_content_text and before_len != len(content): removed_blank_message_content_text = True else: # Replace blank 'text' with '[blank text]' for assistant messages for item in content: - if "text" in item and not item["text"].strip(): + if is_text_block(item) and not item["text"].strip(): replaced_blank_message_content_text = True item["text"] = "[blank text]" @@ -136,13 +144,13 @@ def remove_blank_messages_content_text(messages: Messages) -> Messages: if has_tool_use: # Remove blank 'text' items for assistant messages before_len = len(content) - content[:] = [item for item in content if "text" not in item or item["text"].strip()] + content[:] = [item for item in content if not is_text_block(item) or item["text"].strip()] if not removed_blank_message_content_text and before_len != len(content): removed_blank_message_content_text = True else: # Replace blank 'text' with '[blank text]' for assistant messages for item in content: - if "text" in item and not item["text"].strip(): + if is_text_block(item) and not item["text"].strip(): replaced_blank_message_content_text = True item["text"] = "[blank text]" @@ -298,7 +306,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: state["text"] = "" elif reasoning_text: - content_block: ContentBlock = { + content_block: ContentBlockReasoningContent = { "reasoningContent": { "reasoningText": { "text": state["reasoningText"], diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 68b234729..1854f0c2a 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -15,7 +15,16 @@ from ..event_loop.streaming import process_stream from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_document_block, + is_image_block, + is_reasoning_content_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec @@ -108,7 +117,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Raises: TypeError: If the content block type cannot be converted to an Anthropic-compatible format. """ - if "document" in content: + if is_document_block(content): mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") return { "source": { @@ -124,7 +133,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "document", } - if "image" in content: + if is_image_block(content): return { "source": { "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"), @@ -134,17 +143,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "image", } - if "reasoningContent" in content: + if is_reasoning_content_block(content): return { "signature": content["reasoningContent"]["reasoningText"]["signature"], "thinking": content["reasoningContent"]["reasoningText"]["text"], "type": "thinking", } - if "text" in content: + if is_text_block(content): return {"text": content["text"], "type": "text"} - if "toolUse" in content: + if is_tool_use_block(content): return { "id": content["toolUse"]["toolUseId"], "input": content["toolUse"]["input"], @@ -152,7 +161,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "tool_use", } - if "toolResult" in content: + if is_tool_result_block(content): return { "content": [ self._format_request_message_content( diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 08d8f400c..ee1795b97 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -20,7 +20,20 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec from ..tools._tool_helpers import noop_tool -from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.content import ( + ContentBlock, + Messages, + SystemContentBlock, + is_citations_block, + is_document_block, + is_guard_content_block, + is_image_block, + is_reasoning_content_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, + is_video_block, +) from ..types.exceptions import ( ContextWindowOverflowException, ModelThrottledException, @@ -386,7 +399,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"cachePoint": {"type": content["cachePoint"]["type"]}} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html - if "document" in content: + if is_document_block(content): document = content["document"] result: dict[str, Any] = {} @@ -409,14 +422,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"document": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html - if "guardContent" in content: + if is_guard_content_block(content): guard = content["guardContent"] guard_text = guard["text"] result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}} return {"guardContent": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html - if "image" in content: + if is_image_block(content): image = content["image"] source = image["source"] formatted_source = {} @@ -426,7 +439,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"image": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html - if "reasoningContent" in content: + if is_reasoning_content_block(content): reasoning = content["reasoningContent"] result = {} @@ -445,11 +458,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"reasoningContent": result} # Pass through text and other simple content types - if "text" in content: + if is_text_block(content): return {"text": content["text"]} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html - if "toolResult" in content: + if is_tool_result_block(content): tool_result = content["toolResult"] formatted_content: list[dict[str, Any]] = [] for tool_result_content in tool_result["content"]: @@ -470,7 +483,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"toolResult": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html - if "toolUse" in content: + if is_tool_use_block(content): tool_use = content["toolUse"] return { "toolUse": { @@ -481,7 +494,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An } # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html - if "video" in content: + if is_video_block(content): video = content["video"] source = video["source"] formatted_source = {} @@ -491,7 +504,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An return {"video": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html - if "citationsContent" in content: + if is_citations_block(content): citations = content["citationsContent"] result = {} @@ -777,7 +790,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera # Process content blocks for content in cast(list[ContentBlock], response["output"]["message"]["content"]): # Yield contentBlockStart event if needed - if "toolUse" in content: + if is_tool_use_block(content): yield { "contentBlockStart": { "start": { @@ -793,14 +806,14 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera input_value = json.dumps(content["toolUse"]["input"]) yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} - elif "text" in content: + elif is_text_block(content): # Then yield the text as a delta yield { "contentBlockDelta": { "delta": {"text": content["text"]}, } } - elif "reasoningContent" in content: + elif is_reasoning_content_block(content): # Then yield the reasoning content as a delta yield { "contentBlockDelta": { @@ -818,7 +831,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera } } } - elif "citationsContent" in content: + elif is_citations_block(content): # For non-streaming citations, emit text and metadata deltas in sequence # to match streaming behavior where they flow naturally if "content" in content["citationsContent"]: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index cf7cc604a..9a5d3e1ad 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -12,7 +12,16 @@ from google import genai from typing_extensions import Required, Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_document_block, + is_image_block, + is_reasoning_content_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -143,7 +152,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par Returns: Gemini part. """ - if "document" in content: + if is_document_block(content): return genai.types.Part( inline_data=genai.types.Blob( data=content["document"]["source"]["bytes"], @@ -151,7 +160,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ), ) - if "image" in content: + if is_image_block(content): return genai.types.Part( inline_data=genai.types.Blob( data=content["image"]["source"]["bytes"], @@ -159,7 +168,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ), ) - if "reasoningContent" in content: + if is_reasoning_content_block(content): thought_signature = content["reasoningContent"]["reasoningText"].get("signature") return genai.types.Part( @@ -168,10 +177,10 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par thought_signature=thought_signature.encode("utf-8") if thought_signature else None, ) - if "text" in content: + if is_text_block(content): return genai.types.Part(text=content["text"]) - if "toolResult" in content: + if is_tool_result_block(content): return genai.types.Part( function_response=genai.types.FunctionResponse( id=content["toolResult"]["toolUseId"], @@ -189,7 +198,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ), ) - if "toolUse" in content: + if is_tool_use_block(content): return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1f1e999d2..36545b8f5 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,7 +14,7 @@ from typing_extensions import Unpack, override from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.content import ContentBlock, Messages, SystemContentBlock, is_reasoning_content_block, is_video_block from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import MetadataEvent, StreamEvent @@ -95,14 +95,14 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> Raises: TypeError: If the content block type cannot be converted to a LiteLLM-compatible format. """ - if "reasoningContent" in content: + if is_reasoning_content_block(content): return { "signature": content["reasoningContent"]["reasoningText"]["signature"], "thinking": content["reasoningContent"]["reasoningText"]["text"], "type": "thinking", } - if "video" in content: + if is_video_block(content): return { "type": "video_url", "video_url": { diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 013cd2c7d..46ff42d94 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -15,7 +15,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_image_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -103,7 +110,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Raises: TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format. """ - if "image" in content: + if is_image_block(content): mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") @@ -114,7 +121,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "image_url", } - if "text" in content: + if is_text_block(content): return {"text": content["text"], "type": "text"} raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -179,17 +186,17 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s formatted_contents = [ self._format_request_message_content(content) for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + if not (is_tool_result_block(content) or is_tool_use_block(content)) ] formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents - if "toolUse" in content + if is_tool_use_block(content) ] formatted_tool_messages = [ self._format_request_tool_message(content["toolResult"]) for content in contents - if "toolResult" in content + if is_tool_result_block(content) ] if message["role"] == "assistant": diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 22a3a3873..93bc3ed90 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -30,7 +30,15 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_document_block, + is_image_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -208,20 +216,22 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) Raises: TypeError: If the content block type cannot be converted to a compatible format. """ - if "document" in content: - mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") - file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + if is_document_block(content): + doc = content["document"] + mime_type = mimetypes.types_map.get(f".{doc['format']}", "application/octet-stream") + file_data = base64.b64encode(doc["source"]["bytes"]).decode("utf-8") return { "file": { "file_data": f"data:{mime_type};base64,{file_data}", - "filename": content["document"]["name"], + "filename": doc["name"], }, "type": "file", } - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + if is_image_block(content): + img = content["image"] + mime_type = mimetypes.types_map.get(f".{img['format']}", "application/octet-stream") + image_data = base64.b64encode(img["source"]["bytes"]).decode("utf-8") return { "image_url": { "detail": "auto", @@ -232,16 +242,17 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) } # Handle audio content (not in standard ContentBlock but supported by llama.cpp) - if "audio" in content: - audio_content = cast(Dict[str, Any], content) - audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") - audio_format = audio_content["audio"].get("format", "wav") + # Audio is a llama.cpp-specific extension, not part of the standard ContentBlock type + content_dict = cast(Dict[str, Any], content) + if "audio" in content_dict: + audio_data = base64.b64encode(content_dict["audio"]["source"]["bytes"]).decode("utf-8") + audio_format = content_dict["audio"].get("format", "wav") return { "type": "input_audio", "input_audio": {"data": audio_data, "format": audio_format}, } - if "text" in content: + if is_text_block(content): return {"text": content["text"], "type": "text"} raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -306,7 +317,7 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No formatted_contents = [ self._format_message_content(content) for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + if not is_tool_result_block(content) and not is_tool_use_block(content) ] formatted_tool_calls = [ self._format_tool_call( @@ -317,7 +328,7 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No } ) for content in contents - if "toolUse" in content + if is_tool_use_block(content) ] formatted_tool_messages = [ self._format_tool_message( @@ -327,7 +338,7 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No } ) for content in contents - if "toolResult" in content + if is_tool_result_block(content) ] formatted_message = { diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index b6459d63f..c49d1a975 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -12,7 +12,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_image_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -127,10 +134,10 @@ def _format_request_message_content(self, content: ContentBlock) -> Union[str, d Raises: TypeError: If the content block type cannot be converted to a Mistral-compatible format. """ - if "text" in content: + if is_text_block(content): return content["text"] - if "image" in content: + if is_image_block(content): image_data = content["image"] if "source" in image_data: @@ -211,13 +218,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s tool_messages: list[dict[str, Any]] = [] for content in contents: - if "text" in content: + if is_text_block(content): formatted_content = self._format_request_message_content(content) if isinstance(formatted_content, str): text_contents.append(formatted_content) - elif "toolUse" in content: + elif is_tool_use_block(content): tool_calls.append(self._format_request_message_tool_call(content["toolUse"])) - elif "toolResult" in content: + elif is_tool_result_block(content): tool_messages.append(self._format_request_tool_message(content["toolResult"])) if text_contents or tool_calls: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 574b24200..1c50ffeb1 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -11,7 +11,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_image_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys, warn_on_tool_choice_not_supported @@ -110,13 +117,13 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> Raises: TypeError: If the content block type cannot be converted to an Ollama-compatible format. """ - if "text" in content: + if is_text_block(content): return [{"role": role, "content": content["text"]}] - if "image" in content: + if is_image_block(content): return [{"role": role, "images": [content["image"]["source"]["bytes"]]}] - if "toolUse" in content: + if is_tool_use_block(content): return [ { "role": role, @@ -131,7 +138,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> } ] - if "toolResult" in content: + if is_tool_result_block(content): return [ formatted_tool_result_content for tool_result_content in content["toolResult"]["content"] diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 07246c5d6..cc1e2a570 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -15,7 +15,17 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.content import ( + ContentBlock, + Messages, + SystemContentBlock, + is_document_block, + is_image_block, + is_reasoning_content_block, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -126,7 +136,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> Raises: TypeError: If the content block type cannot be converted to an OpenAI-compatible format. """ - if "document" in content: + if is_document_block(content): mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") return { @@ -137,7 +147,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> "type": "file", } - if "image" in content: + if is_image_block(content): mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") @@ -150,7 +160,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> "type": "image_url", } - if "text" in content: + if is_text_block(content): return {"text": content["text"], "type": "text"} raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -270,7 +280,7 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic contents = message["content"] # Check for reasoningContent and warn user - if any("reasoningContent" in content for content in contents): + if any(is_reasoning_content_block(content) for content in contents): logger.warning( "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." ) @@ -278,15 +288,19 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic formatted_contents = [ cls.format_request_message_content(content) for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) + if not ( + is_tool_result_block(content) or is_tool_use_block(content) or is_reasoning_content_block(content) + ) ] formatted_tool_calls = [ - cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + cls.format_request_message_tool_call(content["toolUse"]) + for content in contents + if is_tool_use_block(content) ] formatted_tool_messages = [ cls.format_request_tool_message(content["toolResult"]) for content in contents - if "toolResult" in content + if is_tool_result_block(content) ] formatted_message = { diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 1fe630fdc..67ce9054a 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -4,7 +4,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -12,7 +12,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, is_reasoning_content_block, is_video_block from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec from ._validation import validate_config_keys, warn_on_tool_choice_not_supported @@ -550,16 +550,19 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> # if "text" in content and not isinstance(content["text"], str): # return {"type": "text", "text": str(content["text"])} - if "reasoningContent" in content and content["reasoningContent"]: + if is_reasoning_content_block(content) and content["reasoningContent"]: return { "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), "type": "thinking", } - elif not content.get("reasoningContent"): - content.pop("reasoningContent", None) + elif is_reasoning_content_block(content) and not content.get("reasoningContent"): + # Cast to dict for mutation before passing to parent + mutable_content = cast(dict[str, Any], dict(content)) + mutable_content.pop("reasoningContent", None) + return super().format_request_message_content(cast(ContentBlock, mutable_content)) - if "video" in content: + if is_video_block(content): return { "type": "video_url", "video_url": { diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index a54fc44c3..02ce1c80e 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -13,7 +13,13 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ( + ContentBlock, + Messages, + is_text_block, + is_tool_result_block, + is_tool_use_block, +) from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -96,12 +102,15 @@ def _format_content_vision(content: ContentBlock) -> dict[str, Any]: Raises: TypeError: If the content block type cannot be converted to a Writer-compatible format. """ - if "text" in content: - return {"text": content["text"], "type": "text"} + # Cast to Dict for runtime key checking + content_dict = cast(Dict[str, Any], content) - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + if "text" in content_dict: + return {"text": content_dict["text"], "type": "text"} + + if "image" in content_dict: + mime_type = mimetypes.types_map.get(f".{content_dict['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content_dict["image"]["source"]["bytes"]).decode("utf-8") return { "image_url": { @@ -110,7 +119,7 @@ def _format_content_vision(content: ContentBlock) -> dict[str, Any]: "type": "image_url", } - raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + raise TypeError(f"content_type=<{next(iter(content_dict))}> | unsupported type") return [ _format_content_vision(content) @@ -133,7 +142,7 @@ def _format_content(content: ContentBlock) -> str: Raises: TypeError: If the content block type cannot be converted to a Writer-compatible format. """ - if "text" in content: + if is_text_block(content): return content["text"] raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -226,12 +235,12 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents - if "toolUse" in content + if is_tool_use_block(content) ] formatted_tool_messages = [ self._format_request_tool_message(content["toolResult"]) for content in contents - if "toolResult" in content + if is_tool_result_block(content) ] formatted_message = { diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 52b6d2ef1..de940a567 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -23,7 +23,13 @@ from ...agent.agent import Agent as SAAgent from ...agent.agent import AgentResult as SAAgentResult -from ...types.content import ContentBlock +from ...types.content import ( + ContentBlock, + ContentBlockDocument, + ContentBlockImage, + ContentBlockText, + ContentBlockVideo, +) from ...types.media import ( DocumentContent, DocumentSource, @@ -259,7 +265,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten if isinstance(part_root, TextPart): # Handle TextPart - content_blocks.append(ContentBlock(text=part_root.text)) + content_blocks.append(ContentBlockText(text=part_root.text)) elif isinstance(part_root, FilePart): # Handle FilePart @@ -283,7 +289,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten if file_type == "image": content_blocks.append( - ContentBlock( + ContentBlockImage( image=ImageContent( format=file_format, # type: ignore source=ImageSource(bytes=decoded_bytes), @@ -292,7 +298,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten ) elif file_type == "video": content_blocks.append( - ContentBlock( + ContentBlockVideo( video=VideoContent( format=file_format, # type: ignore source=VideoSource(bytes=decoded_bytes), @@ -301,7 +307,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten ) else: # document or unknown content_blocks.append( - ContentBlock( + ContentBlockDocument( document=DocumentContent( format=file_format, # type: ignore name=file_name, @@ -313,7 +319,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten elif uri_data: # For URI files, create a text representation since Strands ContentBlocks expect bytes content_blocks.append( - ContentBlock( + ContentBlockText( text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data) ) ) @@ -321,7 +327,7 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten # Handle DataPart - convert structured data to JSON text try: data_text = json.dumps(part_root.data, indent=2) - content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text)) + content_blocks.append(ContentBlockText(text="[Structured Data]\n%s" % data_text)) except Exception: logger.exception("Failed to serialize data part") except Exception: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..5230b7762 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -44,7 +44,7 @@ MultiAgentNodeStreamEvent, MultiAgentResultEvent, ) -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, ContentBlockText, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput from ..types.traces import AttributeValue @@ -977,32 +977,32 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: if not dependency_results: # No dependencies - return task as ContentBlocks if isinstance(self.state.task, str): - return [ContentBlock(text=self.state.task)] + return [ContentBlockText(text=self.state.task)] else: return cast(list[ContentBlock], self.state.task) # Combine task with dependency outputs - node_input = [] + node_input: list[ContentBlock] = [] # Add original task if isinstance(self.state.task, str): - node_input.append(ContentBlock(text=f"Original Task: {self.state.task}")) + node_input.append(ContentBlockText(text=f"Original Task: {self.state.task}")) else: # Add task content blocks with a prefix - node_input.append(ContentBlock(text="Original Task:")) + node_input.append(ContentBlockText(text="Original Task:")) node_input.extend(cast(list[ContentBlock], self.state.task)) # Add dependency outputs - node_input.append(ContentBlock(text="\nInputs from previous nodes:")) + node_input.append(ContentBlockText(text="\nInputs from previous nodes:")) for dep_id, node_result in dependency_results.items(): - node_input.append(ContentBlock(text=f"\nFrom {dep_id}:")) + node_input.append(ContentBlockText(text=f"\nFrom {dep_id}:")) # Get all agent results from this node (flattened if nested) agent_results = node_result.get_agent_results() for result in agent_results: agent_name = getattr(result, "agent_name", "Agent") result_text = str(result) - node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}")) + node_input.append(ContentBlockText(text=f" - {agent_name}: {result_text}")) return node_input diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..453683f8c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -47,7 +47,7 @@ MultiAgentNodeStreamEvent, MultiAgentResultEvent, ) -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, ContentBlockText, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput from ..types.traces import AttributeValue @@ -842,7 +842,7 @@ async def _execute_node( else: # Prepare context for node context_text = self._build_node_input(node) - node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] + node_input = [ContentBlockText(text=f"Context:\n{context_text}\n\n")] # Clear handoff message after it's been included in context self.state.handoff_message = None diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a8ac099d9..312789fad 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -5,7 +5,7 @@ from ..agent.state import AgentState from ..tools._tool_helpers import generate_missing_tool_result_content -from ..types.content import Message +from ..types.content import Message, is_tool_result_block, is_tool_use_block from ..types.exceptions import SessionException from ..types.session import ( Session, @@ -198,16 +198,16 @@ def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: # Check all but the latest message in the messages array # The latest message being orphaned is handled in the agent class if index + 1 < len(messages): - if any("toolUse" in content for content in message["content"]): + if any(is_tool_use_block(content) for content in message["content"]): tool_use_ids = [ - content["toolUse"]["toolUseId"] for content in message["content"] if "toolUse" in content + content["toolUse"]["toolUseId"] for content in message["content"] if is_tool_use_block(content) ] # Check if there are more messages after the current toolUse message tool_result_ids = [ content["toolResult"]["toolUseId"] for content in messages[index + 1]["content"] - if "toolResult" in content + if is_tool_result_block(content) ] missing_tool_use_ids = list(set(tool_use_ids) - set(tool_result_ids)) diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py index 839d6d910..a23744f5f 100644 --- a/src/strands/tools/_validator.py +++ b/src/strands/tools/_validator.py @@ -1,7 +1,7 @@ """Tool validation utilities.""" from ..tools.tools import InvalidToolUseNameException, validate_tool_use -from ..types.content import Message +from ..types.content import Message, is_tool_use_block from ..types.tools import ToolResult, ToolUse @@ -21,7 +21,7 @@ def validate_and_prepare_tools( """ # Extract tool uses from message for content in message["content"]: - if isinstance(content, dict) and "toolUse" in content: + if is_tool_use_block(content): tool_uses.append(content["toolUse"]) # Validate tool uses diff --git a/src/strands/types/__init__.py b/src/strands/types/__init__.py index 7eef60cb4..fea2c6e46 100644 --- a/src/strands/types/__init__.py +++ b/src/strands/types/__init__.py @@ -1,5 +1,6 @@ """SDK type definitions.""" from .collections import PaginatedList +from .content import CONTENT_BLOCK_KEYS, ContentBlockText -__all__ = ["PaginatedList"] +__all__ = ["PaginatedList", "ContentBlockText", "CONTENT_BLOCK_KEYS"] diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..bd1c71da6 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,9 +6,9 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict, TypeGuard from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent @@ -71,32 +71,274 @@ class CachePoint(TypedDict): type: str -class ContentBlock(TypedDict, total=False): - """A block of content for a message that you pass to, or receive from, a model. +class ContentBlockText(TypedDict): + """A content block containing text. Attributes: + text: Text to include in the message. cachePoint: A cache point configuration to optimize conversation history. - document: A document to include in the message. - guardContent: Contains the content to assess with the guardrail. + """ + + text: str + cachePoint: NotRequired[CachePoint] + + +class ContentBlockImage(TypedDict): + """A content block containing an image. + + Attributes: image: Image to include in the message. - reasoningContent: Contains content regarding the reasoning that is carried out by the model. - text: Text to include in the message. - toolResult: The result for a tool request that a model makes. - toolUse: Information about a tool use request from a model. - video: Video to include in the message. - citationsContent: Contains the citations for a document. + cachePoint: A cache point configuration to optimize conversation history. + """ + + image: ImageContent + cachePoint: NotRequired[CachePoint] + + +class ContentBlockDocument(TypedDict): + """A content block containing a document. + + Attributes: + document: A document to include in the message. + cachePoint: A cache point configuration to optimize conversation history. """ - cachePoint: CachePoint document: DocumentContent + cachePoint: NotRequired[CachePoint] + + +class ContentBlockVideo(TypedDict): + """A content block containing a video. + + Attributes: + video: Video to include in the message. + cachePoint: A cache point configuration to optimize conversation history. + """ + + video: VideoContent + cachePoint: NotRequired[CachePoint] + + +class ContentBlockToolUse(TypedDict): + """A content block containing a tool use request. + + Attributes: + toolUse: Information about a tool use request from a model. + cachePoint: A cache point configuration to optimize conversation history. + """ + + toolUse: ToolUse + cachePoint: NotRequired[CachePoint] + + +class ContentBlockToolResult(TypedDict): + """A content block containing a tool result. + + Attributes: + toolResult: The result for a tool request that a model makes. + cachePoint: A cache point configuration to optimize conversation history. + """ + + toolResult: ToolResult + cachePoint: NotRequired[CachePoint] + + +class ContentBlockGuardContent(TypedDict): + """A content block containing content to be evaluated by guardrails. + + Attributes: + guardContent: Contains the content to assess with the guardrail. + cachePoint: A cache point configuration to optimize conversation history. + """ + guardContent: GuardContent - image: ImageContent + cachePoint: NotRequired[CachePoint] + + +class ContentBlockReasoningContent(TypedDict): + """A content block containing reasoning content. + + Attributes: + reasoningContent: Contains content regarding the reasoning that is carried out by the model. + cachePoint: A cache point configuration to optimize conversation history. + """ + reasoningContent: ReasoningContentBlock - text: str - toolResult: ToolResult - toolUse: ToolUse - video: VideoContent + cachePoint: NotRequired[CachePoint] + + +class ContentBlockCitations(TypedDict): + """A content block containing citations. + + Attributes: + citationsContent: Contains the citations for a document. + cachePoint: A cache point configuration to optimize conversation history. + """ + citationsContent: CitationsContentBlock + cachePoint: NotRequired[CachePoint] + + +ContentBlock = Union[ + ContentBlockText, + ContentBlockImage, + ContentBlockDocument, + ContentBlockVideo, + ContentBlockToolUse, + ContentBlockToolResult, + ContentBlockGuardContent, + ContentBlockReasoningContent, + ContentBlockCitations, +] +"""A block of content for a message that you pass to, or receive from, a model. + +This is a union type where each variant contains exactly one type of content (text, image, document, video, +toolUse, toolResult, guardContent, reasoningContent, or citationsContent). Each variant may optionally include +a cachePoint configuration. + +Based on the Bedrock API specification, a ContentBlock must contain one and only one of the content types. + +For constructing content blocks, use the specific types: +- ContentBlockText(text="...") +- ContentBlockImage(image=...) +- ContentBlockDocument(document=...) +- ContentBlockVideo(video=...) +- ContentBlockToolUse(toolUse=...) +- ContentBlockToolResult(toolResult=...) +- ContentBlockGuardContent(guardContent=...) +- ContentBlockReasoningContent(reasoningContent=...) +- ContentBlockCitations(citationsContent=...) +""" + + +# Type alias for content block input that accepts both typed and untyped dicts +ContentBlockInput = Union[ContentBlock, Dict[str, Any]] + + +# Set of all valid keys that can appear in a ContentBlock +CONTENT_BLOCK_KEYS = frozenset( + { + "text", + "image", + "document", + "video", + "toolUse", + "toolResult", + "guardContent", + "reasoningContent", + "citationsContent", + "cachePoint", + } +) + + +# Type guard functions for narrowing ContentBlock union types +def is_text_block(content: ContentBlockInput) -> TypeGuard[ContentBlockText]: + """Check if a content block is a text block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains text content. + """ + return isinstance(content, dict) and "text" in content + + +def is_image_block(content: ContentBlockInput) -> TypeGuard[ContentBlockImage]: + """Check if a content block is an image block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains image content. + """ + return isinstance(content, dict) and "image" in content + + +def is_document_block(content: ContentBlockInput) -> TypeGuard[ContentBlockDocument]: + """Check if a content block is a document block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains document content. + """ + return isinstance(content, dict) and "document" in content + + +def is_video_block(content: ContentBlockInput) -> TypeGuard[ContentBlockVideo]: + """Check if a content block is a video block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains video content. + """ + return isinstance(content, dict) and "video" in content + + +def is_tool_use_block(content: ContentBlockInput) -> TypeGuard[ContentBlockToolUse]: + """Check if a content block is a tool use block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains tool use content. + """ + return isinstance(content, dict) and "toolUse" in content + + +def is_tool_result_block(content: ContentBlockInput) -> TypeGuard[ContentBlockToolResult]: + """Check if a content block is a tool result block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains tool result content. + """ + return isinstance(content, dict) and "toolResult" in content + + +def is_guard_content_block(content: ContentBlockInput) -> TypeGuard[ContentBlockGuardContent]: + """Check if a content block is a guard content block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains guard content. + """ + return isinstance(content, dict) and "guardContent" in content + + +def is_reasoning_content_block(content: ContentBlockInput) -> TypeGuard[ContentBlockReasoningContent]: + """Check if a content block is a reasoning content block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains reasoning content. + """ + return isinstance(content, dict) and "reasoningContent" in content + + +def is_citations_block(content: ContentBlockInput) -> TypeGuard[ContentBlockCitations]: + """Check if a content block is a citations block. + + Args: + content: The content block to check. + + Returns: + True if the content block contains citations content. + """ + return isinstance(content, dict) and "citationsContent" in content class SystemContentBlock(TypedDict, total=False): diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 1463d3f48..c96d05cce 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -9,7 +9,7 @@ from strands.agent.agent_result import AgentResult as SAAgentResult from strands.multiagent.a2a.executor import StrandsA2AExecutor -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText # Test data constants VALID_PNG_BYTES = b"fake_png_data" @@ -97,7 +97,7 @@ def test_convert_a2a_parts_to_content_blocks_text_part(): result = executor._convert_a2a_parts_to_content_blocks([part]) assert len(result) == 1 - assert result[0] == ContentBlock(text="Hello, world!") + assert result[0] == ContentBlockText(text="Hello, world!") def test_convert_a2a_parts_to_content_blocks_file_part_image_bytes(): diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..205fc0cb8 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -14,7 +14,7 @@ from strands.session.file_session_manager import FileSessionManager from strands.session.session_manager import SessionManager from strands.types._events import MultiAgentNodeStartEvent -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None, should_fail=False): @@ -236,7 +236,7 @@ def test_swarm_state_should_continue(mock_swarm): async def test_swarm_execution_async(mock_strands_tracer, mock_use_span, mock_swarm, mock_agents): """Test asynchronous swarm execution.""" # Execute swarm - task = [ContentBlock(text="Analyze this task"), ContentBlock(text="Additional context")] + task = [ContentBlockText(text="Analyze this task"), ContentBlockText(text="Additional context")] result = await mock_swarm.invoke_async(task) # Verify execution results diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 7e28be998..e37f04f38 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -9,7 +9,7 @@ from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.session.file_session_manager import FileSessionManager -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -47,7 +47,7 @@ def sample_message(): return SessionMessage.from_message( message={ "role": "user", - "content": [ContentBlock(text="Hello world")], + "content": [ContentBlockText(text="Hello world")], }, index=0, ) @@ -263,7 +263,7 @@ def test_list_messages_all(file_manager, sample_session, sample_agent): message = SessionMessage( message={ "role": "user", - "content": [ContentBlock(text=f"Message {i}")], + "content": [ContentBlockText(text=f"Message {i}")], }, message_id=i, ) @@ -287,7 +287,7 @@ def test_list_messages_with_limit(file_manager, sample_session, sample_agent): message = SessionMessage( message={ "role": "user", - "content": [ContentBlock(text=f"Message {i}")], + "content": [ContentBlockText(text=f"Message {i}")], }, message_id=i, ) @@ -310,7 +310,7 @@ def test_list_messages_with_offset(file_manager, sample_session, sample_agent): message = SessionMessage( message={ "role": "user", - "content": [ContentBlock(text=f"Message {i}")], + "content": [ContentBlockText(text=f"Message {i}")], }, message_id=i, ) @@ -341,7 +341,7 @@ def test_update_message(file_manager, sample_session, sample_agent, sample_messa file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) # Update message - sample_message.message["content"] = [ContentBlock(text="Updated content")] + sample_message.message["content"] = [ContentBlockText(text="Updated content")] file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) # Verify update diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 22de9f964..8624fc4d2 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -10,7 +10,7 @@ from strands.agent.state import AgentState from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository @@ -118,7 +118,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): message = SessionMessage( message={ "role": "user", - "content": [ContentBlock(text="Hello")], + "content": [ContentBlockText(text="Hello")], }, message_id=0, ) @@ -153,7 +153,7 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage message = SessionMessage( message={ "role": "user", - "content": [ContentBlock(text="Hello")], + "content": [ContentBlockText(text="Hello")], }, message_id=0, ) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 719fbc2c9..f50c42732 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -11,7 +11,7 @@ from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.session.s3_session_manager import S3SessionManager -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -66,7 +66,7 @@ def sample_message(): return SessionMessage.from_message( message={ "role": "user", - "content": [ContentBlock(text="test_message")], + "content": [ContentBlockText(text="test_message")], }, index=0, ) @@ -269,7 +269,7 @@ def test_list_messages_all(s3_manager, sample_session, sample_agent): message = SessionMessage( { "role": "user", - "content": [ContentBlock(text=f"Message {i}")], + "content": [ContentBlockText(text=f"Message {i}")], }, i, ) @@ -293,7 +293,7 @@ def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent) message = SessionMessage.from_message( message={ "role": "user", - "content": [ContentBlock(text="test_message")], + "content": [ContentBlockText(text="test_message")], }, index=index, ) @@ -316,7 +316,7 @@ def test_update_message(s3_manager, sample_session, sample_agent, sample_message s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) # Update message - sample_message.message["content"] = [ContentBlock(text="Updated content")] + sample_message.message["content"] = [ContentBlockText(text="Updated content")] s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) # Verify update diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index cb98b8130..5ac41cd50 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -10,7 +10,7 @@ ) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize -from strands.types.content import ContentBlock +from strands.types.content import ContentBlockText from strands.types.interrupt import InterruptResponseContent from strands.types.streaming import Metrics, StopReason, Usage @@ -393,7 +393,7 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_span = mock.MagicMock() mock_tracer.start_span.return_value = mock_span - task = [ContentBlock(text="Original Task: foo bar")] + task = [ContentBlockText(text="Original Task: foo bar")] span = tracer.start_multiagent_span(task, "swarm") @@ -411,7 +411,7 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): @pytest.mark.parametrize( "task, expected_parts", [ - ([ContentBlock(text="Test message")], [{"type": "text", "content": "Test message"}]), + ([ContentBlockText(text="Test message")], [{"type": "text", "content": "Test message"}]), ( [InterruptResponseContent(interruptResponse={"interruptId": "test-id", "response": "approved"})], [{"type": "interrupt_response", "id": "test-id", "response": "approved"}], @@ -446,7 +446,7 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_span = mock.MagicMock() mock_tracer.start_span.return_value = mock_span - task = [ContentBlock(text="Original Task: foo bar")] + task = [ContentBlockText(text="Original Task: foo bar")] span = tracer.start_multiagent_span(task, "swarm")