From 1c3ff26e78b5bb0d3a05fc41abac7cfdd0196fa8 Mon Sep 17 00:00:00 2001 From: danielp Date: Mon, 22 Dec 2025 10:44:24 -0300 Subject: [PATCH] feat(gemini): add thought signature preservation for thinking models - Preserve thought_signature for function calls in Gemini thinking models - Fix empty tools list causing 400 Bad Request error - Add base64 encoding/decoding for binary thought_signature data - Add backwards compatibility for toolUse without thoughtSignature - Update tests to reflect thought signature handling Closes #1199 --- src/strands/models/gemini.py | 97 ++++++++++++++++++++--------- tests/strands/models/test_gemini.py | 78 +++++++++++++++-------- 2 files changed, 121 insertions(+), 54 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index cf7cc604a..33022450d 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -3,6 +3,7 @@ - Docs: https://ai.google.dev/api """ +import base64 import json import logging import mimetypes @@ -165,7 +166,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par return genai.types.Part( text=content["reasoningContent"]["reasoningText"]["text"], thought=True, - thought_signature=thought_signature.encode("utf-8") if thought_signature else None, + thought_signature=base64.b64decode(thought_signature) if thought_signature else None, ) if "text" in content: @@ -190,13 +191,26 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: - return genai.types.Part( - function_call=genai.types.FunctionCall( - args=content["toolUse"]["input"], - id=content["toolUse"]["toolUseId"], - name=content["toolUse"]["name"], - ), - ) + # Get thought_signature if present (for Gemini thinking models) + thought_signature = content["toolUse"].get("thoughtSignature") + + # For Gemini thinking models, function calls require thought_signature. + # If missing (e.g., from old session history), convert to text representation + # to preserve context without triggering API errors. + if thought_signature: + return genai.types.Part( + function_call=genai.types.FunctionCall( + args=content["toolUse"]["input"], + id=content["toolUse"]["toolUseId"], + name=content["toolUse"]["name"], + ), + thought_signature=base64.b64decode(thought_signature), + ) + else: + # Convert to text representation for backwards compatibility + tool_name = content["toolUse"]["name"] + tool_input = content["toolUse"]["input"] + return genai.types.Part(text=f"[Called tool: {tool_name} with input: {json.dumps(tool_input)}]") raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -230,20 +244,27 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge Return: Gemini tool list. """ - tools = [ - genai.types.Tool( - function_declarations=[ - genai.types.FunctionDeclaration( - description=tool_spec["description"], - name=tool_spec["name"], - parameters_json_schema=tool_spec["inputSchema"]["json"], - ) - for tool_spec in tool_specs or [] - ], - ), - ] + tools = [] + + # Only add function declarations tool if there are tool specs + if tool_specs: + tools.append( + genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ) + + # Add any Gemini-specific tools if self.config.get("gemini_tools"): tools.extend(self.config["gemini_tools"]) + return tools def _format_request_config( @@ -264,11 +285,19 @@ def _format_request_config( Returns: Gemini request config. """ - return genai.types.GenerateContentConfig( - system_instruction=system_prompt, - tools=self._format_request_tools(tool_specs), + tools = self._format_request_tools(tool_specs) + + # Build config kwargs, only including tools if there are any + config_kwargs = { + "system_instruction": system_prompt, **(params or {}), - ) + } + + # Only include tools parameter if there are actual tools to pass + if tools: + config_kwargs["tools"] = tools + + return genai.types.GenerateContentConfig(**config_kwargs) def _format_request( self, @@ -320,13 +349,19 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: # that name be set in the equivalent FunctionResponse type. Consequently, we assign # function name to toolUseId in our tool use block. And another reason, function_call is # not guaranteed to have id populated. + tool_use_data: dict[str, Any] = { + "name": event["data"].function_call.name, + "toolUseId": event["data"].function_call.name, + } + # Capture thought_signature for Gemini thinking models (base64 encoded) + if event["data"].thought_signature: + tool_use_data["thoughtSignature"] = base64.b64encode( + event["data"].thought_signature + ).decode("ascii") return { "contentBlockStart": { "start": { - "toolUse": { - "name": event["data"].function_call.name, - "toolUseId": event["data"].function_call.name, - }, + "toolUse": tool_use_data, }, }, } @@ -350,7 +385,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "reasoningContent": { "text": event["data"].text, **( - {"signature": event["data"].thought_signature.decode("utf-8")} + { + "signature": base64.b64encode(event["data"].thought_signature).decode( + "ascii" + ) + } if event["data"].thought_signature else {} ), diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index c552a892a..239f7a314 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -84,7 +84,7 @@ async def test_stream_request_default(gemini_client, model, messages, model_id): await anext(model.stream(messages)) exp_request = { - "config": {"tools": [{"function_declarations": []}]}, + "config": {}, "contents": [{"parts": [{"text": "test"}], "role": "user"}], "model": model_id, } @@ -99,7 +99,6 @@ async def test_stream_request_with_params(gemini_client, model, messages, model_ exp_request = { "config": { - "tools": [{"function_declarations": []}], "temperature": 1, }, "contents": [{"parts": [{"text": "test"}], "role": "user"}], @@ -113,7 +112,7 @@ async def test_stream_request_with_system_prompt(gemini_client, model, messages, await anext(model.stream(messages, system_prompt=system_prompt)) exp_request = { - "config": {"system_instruction": system_prompt, "tools": [{"function_declarations": []}]}, + "config": {"system_instruction": system_prompt}, "contents": [{"parts": [{"text": "test"}], "role": "user"}], "model": model_id, } @@ -146,9 +145,7 @@ async def test_stream_request_with_document(content, formatted_part, gemini_clie await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [{"parts": [formatted_part], "role": "user"}], "model": model_id, } @@ -173,9 +170,7 @@ async def test_stream_request_with_image(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -203,7 +198,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): { "reasoningContent": { "reasoningText": { - "signature": "abc", + "signature": "YWJj", # base64 of "abc" "text": "reasoning_text", }, }, @@ -214,9 +209,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -260,6 +253,7 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too @pytest.mark.asyncio async def test_stream_request_with_tool_use(gemini_client, model, model_id): + """Test toolUse with thoughtSignature is sent as function_call.""" messages = [ { "role": "assistant", @@ -269,6 +263,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): "toolUseId": "c1", "name": "calculator", "input": {"expression": "2+2"}, + "thoughtSignature": "YWJj", # base64 of "abc" - required for Gemini thinking models }, }, ], @@ -277,9 +272,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -289,6 +282,45 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): "id": "c1", "name": "calculator", }, + "thought_signature": "YWJj", + }, + ], + "role": "model", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_use_no_thought_signature(gemini_client, model, model_id): + """Test toolUse without thoughtSignature is converted to text for backwards compatibility.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + # No thoughtSignature - simulates old session history + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + # Without thoughtSignature, toolUse is converted to text representation + exp_request = { + "config": {}, + "contents": [ + { + "parts": [ + { + "text": '[Called tool: calculator with input: {"expression": "2+2"}]', }, ], "role": "model", @@ -327,9 +359,7 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -371,9 +401,7 @@ async def test_stream_request_with_empty_content(gemini_client, model, model_id) await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [{"parts": [], "role": "user"}], "model": model_id, } @@ -497,10 +525,11 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera ) tru_chunks = await alist(model.stream(messages)) + # signature is base64 encoded: b"abc" -> "YWJj" exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, @@ -614,7 +643,6 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath exp_request = { "config": { - "tools": [{"function_declarations": []}], "response_mime_type": "application/json", "response_schema": weather_output.model_json_schema(), }, @@ -666,10 +694,10 @@ async def test_stream_request_with_gemini_tools(gemini_client, messages, model_i await anext(model.stream(messages)) + # When only gemini_tools are provided (no tool_specs), only gemini_tools are included exp_request = { "config": { "tools": [ - {"function_declarations": []}, {"google_search": {}}, ] },