97 changes: 68 additions & 29 deletions src/strands/models/gemini.py
@@ -3,6 +3,7 @@
- Docs: https://ai.google.dev/api
"""

import base64
import json
import logging
import mimetypes
@@ -165,7 +166,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
return genai.types.Part(
text=content["reasoningContent"]["reasoningText"]["text"],
thought=True,
thought_signature=thought_signature.encode("utf-8") if thought_signature else None,
thought_signature=base64.b64decode(thought_signature) if thought_signature else None,
)

if "text" in content:
@@ -190,13 +191,26 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
)

if "toolUse" in content:
return genai.types.Part(
function_call=genai.types.FunctionCall(
args=content["toolUse"]["input"],
id=content["toolUse"]["toolUseId"],
name=content["toolUse"]["name"],
),
)
# Get thought_signature if present (for Gemini thinking models)
thought_signature = content["toolUse"].get("thoughtSignature")

# For Gemini thinking models, function calls require thought_signature.
# If missing (e.g., from old session history), convert to text representation
# to preserve context without triggering API errors.
if thought_signature:
return genai.types.Part(
function_call=genai.types.FunctionCall(
args=content["toolUse"]["input"],
id=content["toolUse"]["toolUseId"],
name=content["toolUse"]["name"],
),
thought_signature=base64.b64decode(thought_signature),
)
else:
# Convert to text representation for backwards compatibility
tool_name = content["toolUse"]["name"]
tool_input = content["toolUse"]["input"]
return genai.types.Part(text=f"[Called tool: {tool_name} with input: {json.dumps(tool_input)}]")

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
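Note: the round-trip contract behind this hunk, shown standalone below — signatures are stored as base64 text (JSON-safe) in session history and decoded back to raw bytes before being sent to the API, while tool-use blocks without a stored signature degrade to a plain-text rendering. This is a minimal sketch with a stand-in Part dataclass and a hypothetical format_tool_use helper, not the real genai.types API.

import base64
import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class Part:  # stand-in for genai.types.Part, not the real type
    text: Optional[str] = None
    function_call: Optional[dict] = None
    thought_signature: Optional[bytes] = None

def format_tool_use(tool_use: dict) -> Part:
    signature = tool_use.get("thoughtSignature")
    if signature:
        # Session history stores the signature as base64 text; the API wants raw bytes.
        return Part(
            function_call={"id": tool_use["toolUseId"], "name": tool_use["name"], "args": tool_use["input"]},
            thought_signature=base64.b64decode(signature),
        )
    # No signature (e.g., history recorded before this change): keep the context
    # as plain text instead of sending a function_call the API would reject.
    return Part(text=f"[Called tool: {tool_use['name']} with input: {json.dumps(tool_use['input'])}]")

part = format_tool_use({"toolUseId": "c1", "name": "calculator", "input": {"expression": "2+2"}, "thoughtSignature": "YWJj"})
assert part.thought_signature == b"abc"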

@@ -230,20 +244,27 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[genai.types.Tool]:
Returns:
Gemini tool list.
"""
tools = [
genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
description=tool_spec["description"],
name=tool_spec["name"],
parameters_json_schema=tool_spec["inputSchema"]["json"],
)
for tool_spec in tool_specs or []
],
),
]
tools = []

# Only add function declarations tool if there are tool specs
if tool_specs:
tools.append(
genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
description=tool_spec["description"],
name=tool_spec["name"],
parameters_json_schema=tool_spec["inputSchema"]["json"],
)
for tool_spec in tool_specs
],
),
)

# Add any Gemini-specific tools
if self.config.get("gemini_tools"):
tools.extend(self.config["gemini_tools"])

return tools

def _format_request_config(
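Note: the behavioral change in _format_request_tools, reduced to its essentials below — the old code always emitted a tool with an empty function_declarations list, while the new code leaves the list empty unless tool specs exist and then appends any provider-native tools from config. Plain dicts stand in for genai.types.Tool (assumption: gemini_tools entries are already in wire format, as the tests suggest).

from typing import Any, Optional

def build_tools(tool_specs: Optional[list[dict[str, Any]]], config: dict[str, Any]) -> list[dict[str, Any]]:
    tools: list[dict[str, Any]] = []
    if tool_specs:
        # Emit a function_declarations tool only when there are declarations.
        tools.append({
            "function_declarations": [
                {
                    "description": spec["description"],
                    "name": spec["name"],
                    "parameters_json_schema": spec["inputSchema"]["json"],
                }
                for spec in tool_specs
            ],
        })
    # Gemini-specific tools (e.g., {"google_search": {}}) pass through verbatim.
    tools.extend(config.get("gemini_tools") or [])
    return tools

assert build_tools(None, {}) == []
assert build_tools(None, {"gemini_tools": [{"google_search": {}}]}) == [{"google_search": {}}]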
@@ -264,11 +285,19 @@ def _format_request_config(
Returns:
Gemini request config.
"""
return genai.types.GenerateContentConfig(
system_instruction=system_prompt,
tools=self._format_request_tools(tool_specs),
tools = self._format_request_tools(tool_specs)

# Build config kwargs, only including tools if there are any
config_kwargs = {
"system_instruction": system_prompt,
**(params or {}),
)
}

# Only include tools parameter if there are actual tools to pass
if tools:
config_kwargs["tools"] = tools

return genai.types.GenerateContentConfig(**config_kwargs)

def _format_request(
self,
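Note: the same only-if-present pattern applied one level up to the request config — sketched below with a plain dict in place of genai.types.GenerateContentConfig — ensures the tools key is absent, rather than present but empty, when there are no tools.

from typing import Any, Optional

def build_config_kwargs(system_prompt: Optional[str], tools: list, params: Optional[dict] = None) -> dict[str, Any]:
    config_kwargs: dict[str, Any] = {"system_instruction": system_prompt, **(params or {})}
    if tools:
        config_kwargs["tools"] = tools  # omit the key entirely when there are no tools
    return config_kwargs

assert "tools" not in build_config_kwargs("be brief", [])
assert build_config_kwargs(None, [], {"temperature": 1})["temperature"] == 1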
@@ -320,13 +349,19 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
# that name be set in the equivalent FunctionResponse type. Consequently, we assign
# the function name to toolUseId in our tool use block. Additionally, function_call is
# not guaranteed to have id populated.
tool_use_data: dict[str, Any] = {
"name": event["data"].function_call.name,
"toolUseId": event["data"].function_call.name,
}
# Capture thought_signature for Gemini thinking models (base64 encoded)
if event["data"].thought_signature:
tool_use_data["thoughtSignature"] = base64.b64encode(
event["data"].thought_signature
).decode("ascii")
return {
"contentBlockStart": {
"start": {
"toolUse": {
"name": event["data"].function_call.name,
"toolUseId": event["data"].function_call.name,
},
"toolUse": tool_use_data,
},
},
}
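Note: on the streaming side the transformation runs in reverse — raw signature bytes from the SDK are base64-encoded so the resulting toolUse block survives JSON session storage. Standalone sketch, with SimpleNamespace standing in for the SDK's part object:

import base64
from types import SimpleNamespace

def format_tool_use_start(data: SimpleNamespace) -> dict:
    tool_use = {"name": data.function_call.name, "toolUseId": data.function_call.name}
    if data.thought_signature:
        # bytes -> base64 text, so the block stays JSON-serializable in session storage
        tool_use["thoughtSignature"] = base64.b64encode(data.thought_signature).decode("ascii")
    return {"contentBlockStart": {"start": {"toolUse": tool_use}}}

data = SimpleNamespace(function_call=SimpleNamespace(name="calculator"), thought_signature=b"abc")
assert format_tool_use_start(data)["contentBlockStart"]["start"]["toolUse"]["thoughtSignature"] == "YWJj"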
@@ -350,7 +385,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"reasoningContent": {
"text": event["data"].text,
**(
{"signature": event["data"].thought_signature.decode("utf-8")}
{
"signature": base64.b64encode(event["data"].thought_signature).decode(
"ascii"
)
}
if event["data"].thought_signature
else {}
),
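Note: the reason base64 replaces the old utf-8 encode/decode throughout this file is that thought signatures are arbitrary bytes, not text — utf-8 decoding can raise or corrupt them, while base64 round-trips losslessly:

import base64

raw = bytes(range(256))  # arbitrary signature bytes, not valid UTF-8
stored = base64.b64encode(raw).decode("ascii")  # _format_chunk direction: bytes -> text
assert base64.b64decode(stored) == raw  # _format_request direction: text -> bytes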
78 changes: 53 additions & 25 deletions tests/strands/models/test_gemini.py
@@ -84,7 +84,7 @@ async def test_stream_request_default(gemini_client, model, messages, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {"tools": [{"function_declarations": []}]},
"config": {},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
@@ -99,7 +99,6 @@ async def test_stream_request_with_params(gemini_client, model, messages, model_id):

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
"temperature": 1,
},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
@@ -113,7 +112,7 @@ async def test_stream_request_with_system_prompt(gemini_client, model, messages,
await anext(model.stream(messages, system_prompt=system_prompt))

exp_request = {
"config": {"system_instruction": system_prompt, "tools": [{"function_declarations": []}]},
"config": {"system_instruction": system_prompt},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
@@ -146,9 +145,7 @@ async def test_stream_request_with_document(content, formatted_part, gemini_client,
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [{"parts": [formatted_part], "role": "user"}],
"model": model_id,
}
@@ -173,9 +170,7 @@ async def test_stream_request_with_image(gemini_client, model, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [
{
"parts": [
@@ -203,7 +198,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id):
{
"reasoningContent": {
"reasoningText": {
"signature": "abc",
"signature": "YWJj", # base64 of "abc"
"text": "reasoning_text",
},
},
@@ -214,9 +209,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [
{
"parts": [
@@ -260,6 +253,7 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, tool_specs):

@pytest.mark.asyncio
async def test_stream_request_with_tool_use(gemini_client, model, model_id):
"""Test toolUse with thoughtSignature is sent as function_call."""
messages = [
{
"role": "assistant",
@@ -269,6 +263,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id):
"toolUseId": "c1",
"name": "calculator",
"input": {"expression": "2+2"},
"thoughtSignature": "YWJj", # base64 of "abc" - required for Gemini thinking models
},
},
],
@@ -277,9 +272,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [
{
"parts": [
@@ -289,6 +282,45 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id):
"id": "c1",
"name": "calculator",
},
"thought_signature": "YWJj",
},
],
"role": "model",
},
],
"model": model_id,
}
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)


@pytest.mark.asyncio
async def test_stream_request_with_tool_use_no_thought_signature(gemini_client, model, model_id):
"""Test toolUse without thoughtSignature is converted to text for backwards compatibility."""
messages = [
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "c1",
"name": "calculator",
"input": {"expression": "2+2"},
# No thoughtSignature - simulates old session history
},
},
],
},
]
await anext(model.stream(messages))

# Without thoughtSignature, toolUse is converted to text representation
exp_request = {
"config": {},
"contents": [
{
"parts": [
{
"text": '[Called tool: calculator with input: {"expression": "2+2"}]',
},
],
"role": "model",
@@ -327,9 +359,7 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [
{
"parts": [
@@ -371,9 +401,7 @@ async def test_stream_request_with_empty_content(gemini_client, model, model_id):
await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
},
"config": {},
"contents": [{"parts": [], "role": "user"}],
"model": model_id,
}
@@ -497,10 +525,11 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenerator, alist):
)

tru_chunks = await alist(model.stream(messages))
# signature is base64 encoded: b"abc" -> "YWJj"
exp_chunks = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
{"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}},
@@ -614,7 +643,6 @@ async def test_structured_output(gemini_client, model, messages, model_id, weather_output):

exp_request = {
"config": {
"tools": [{"function_declarations": []}],
"response_mime_type": "application/json",
"response_schema": weather_output.model_json_schema(),
},
@@ -666,10 +694,10 @@ async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id):

await anext(model.stream(messages))

# When only gemini_tools are provided (no tool_specs), only gemini_tools are included
exp_request = {
"config": {
"tools": [
{"function_declarations": []},
{"google_search": {}},
]
},
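Note: for reference, the configuration this last test exercises would look roughly like the following; the class name and constructor kwargs are assumptions inferred from the tests, not confirmed by the diff.

# Hypothetical usage sketch; names and kwargs are assumed, not confirmed.
from strands.models.gemini import GeminiModel

model = GeminiModel(
    model_id="gemini-2.5-flash",  # assumed model id
    gemini_tools=[{"google_search": {}}],  # forwarded verbatim into the request's tools
)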