Skip to content

Commit 516790f

Browse files
committed
refactor: inline request/response processing
1 parent 0c5c906 commit 516790f

File tree

7 files changed

+895
-845
lines changed

7 files changed

+895
-845
lines changed

contributing/samples/hello_world_gemma/agent.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
import random
1717

1818
from google.adk.agents.llm_agent import Agent
19-
from google.adk.models.google_llm import Gemma
20-
from google.adk.models.google_llm import gemma_functions_after_model_callback
21-
from google.adk.models.google_llm import gemma_functions_before_model_callback
19+
from google.adk.models.gemma_llm import Gemma
2220
from google.genai.types import GenerateContentConfig
2321

2422

@@ -90,8 +88,6 @@ async def check_prime(nums: list[int]) -> str:
9088
roll_die,
9189
check_prime,
9290
],
93-
before_model_callback=gemma_functions_before_model_callback,
94-
after_model_callback=gemma_functions_after_model_callback,
9591
generate_content_config=GenerateContentConfig(
9692
temperature=1.0,
9793
top_p=0.95,

src/google/adk/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
"""Defines the interface to support a model."""
1616

1717
from .base_llm import BaseLlm
18+
from .gemma_llm import Gemma
1819
from .google_llm import Gemini
19-
from .google_llm import Gemma
2020
from .llm_request import LlmRequest
2121
from .llm_response import LlmResponse
2222
from .registry import LLMRegistry

src/google/adk/models/gemma_llm.py

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import cached_property
16+
import json
17+
import logging
18+
import re
19+
from typing import Any
20+
from typing import AsyncGenerator
21+
22+
from google.adk.models.google_llm import Gemini
23+
from google.adk.models.llm_request import LlmRequest
24+
from google.adk.models.llm_response import LlmResponse
25+
from google.adk.utils.variant_utils import GoogleLLMVariant
26+
from google.genai import types
27+
from google.genai.types import Content
28+
from google.genai.types import FunctionDeclaration
29+
from google.genai.types import Part
30+
from pydantic import AliasChoices
31+
from pydantic import BaseModel
32+
from pydantic import Field
33+
from pydantic import ValidationError
34+
from typing_extensions import override
35+
36+
logger = logging.getLogger('google_adk.' + __name__)
37+
38+
39+
class GemmaFunctionCallModel(BaseModel):
  """Flexible Pydantic model for parsing inline Gemma function call responses."""

  # Accept both the documented {"name": ..., "parameters": ...} shape and the
  # {"function": ..., "args": ...} variant that models sometimes emit.
  name: str = Field(validation_alias=AliasChoices('name', 'function'))
  parameters: dict[str, Any] = Field(
      validation_alias=AliasChoices('parameters', 'args')
  )
46+
47+
48+
class Gemma(Gemini):
  """Integration for Gemma models exposed via the Gemini API.

  Only Gemma 3 models are supported at this time. For agentic use cases,
  use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended.

  For full documentation, see: https://ai.google.dev/gemma/docs/core/

  NOTE: Gemma does **NOT** support system instructions. Any system instructions
  will be replaced with an initial *user* prompt in the LLM request. If system
  instructions change over the course of agent execution, the initial content
  **SHOULD** be replaced. Special care is warranted here.
  See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions

  NOTE: Gemma's function calling support is limited. It does not have full
  access to the same built-in tools as Gemini. It also does not have special
  API support for tools and functions. Rather, tools must be passed in via a
  `user` prompt, and extracted from model responses based on approximate shape.

  NOTE: Vertex AI API support for Gemma is not currently included. This
  **ONLY** supports usage via the Gemini API.
  """

  model: str = (
      'gemma-3-27b-it'  # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
  )

  @classmethod
  @override
  def supported_models(cls) -> list[str]:
    """Provides the list of supported models.

    Returns:
      A list of supported model name patterns.
    """
    return [
        r'gemma-3.*',
    ]

  @cached_property
  def _api_backend(self) -> GoogleLLMVariant:
    # Gemma is only reachable through the Gemini API; Vertex AI is unsupported.
    return GoogleLLMVariant.GEMINI_API

  def _move_function_calls_into_system_instruction(
      self, llm_request: LlmRequest
  ) -> None:
    """Rewrites tool declarations and tool traffic into plain text, in place.

    Gemma has no native function-calling API, so:
      - function call/response parts in the history are converted to text,
      - declared tool functions are folded into the system instruction, and
      - `config.tools` is cleared so nothing is sent via the tools API field.

    Args:
      llm_request: The request to rewrite in place.
    """
    if llm_request.model is None or not llm_request.model.startswith('gemma-3'):
      return

    # Iterate through the existing contents to find and convert function calls
    # and responses from text parts, as Gemma models don't directly support
    # function calling.
    new_contents: list[Content] = []
    for content_item in llm_request.contents:
      (
          new_parts_for_content,
          has_function_response_part,
          has_function_call_part,
      ) = _convert_content_parts_for_gemma(content_item)

      if has_function_response_part:
        # Tool output is relayed back to the model as a user turn.
        if new_parts_for_content:
          new_contents.append(Content(role='user', parts=new_parts_for_content))
      elif has_function_call_part:
        # The model's own (textified) function calls stay on the model turn.
        if new_parts_for_content:
          new_contents.append(
              Content(role='model', parts=new_parts_for_content)
          )
      else:
        new_contents.append(content_item)

    llm_request.contents = new_contents

    if not llm_request.config.tools:
      return

    all_function_declarations: list[FunctionDeclaration] = []
    for tool_item in llm_request.config.tools:
      if isinstance(tool_item, types.Tool) and tool_item.function_declarations:
        all_function_declarations.extend(tool_item.function_declarations)

    if all_function_declarations:
      system_instruction = _build_gemma_function_system_instruction(
          all_function_declarations
      )
      llm_request.append_instructions([system_instruction])

    # Tools were inlined above; never send them through the API tools field.
    llm_request.config.tools = []

  def _extract_function_calls_from_response(self, llm_response: LlmResponse):
    """Parses an inline JSON function call out of a Gemma text response.

    If the complete, single-part text response contains a JSON object matching
    the declared call format, the text part is replaced in place with a real
    `FunctionCall` part. Otherwise the response is left untouched.

    Args:
      llm_response: The response to rewrite in place.
    """
    # Skip streaming fragments and bare turn-completion markers.
    if llm_response.partial or (llm_response.turn_complete is True):
      return

    if not llm_response.content:
      return

    if not llm_response.content.parts:
      return

    # A function call reply must be the only part; the prompt format forbids
    # mixing a call with other text.
    if len(llm_response.content.parts) > 1:
      return

    response_text = llm_response.content.parts[0].text
    if not response_text:
      return

    try:
      json_candidate = None

      # Prefer an explicit fenced block (```json ...``` or ```tool_code ...```).
      markdown_code_block_pattern = re.compile(
          r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL
      )
      block_match = markdown_code_block_pattern.search(response_text)

      if block_match:
        json_candidate = block_match.group(2).strip()
      else:
        # Fall back to the last bare JSON object embedded in the text.
        found, json_text = _get_last_valid_json_substring(response_text)
        if found:
          json_candidate = json_text

      if not json_candidate:
        return

      function_call_parsed = GemmaFunctionCallModel.model_validate_json(
          json_candidate
      )
      function_call = types.FunctionCall(
          name=function_call_parsed.name,
          args=function_call_parsed.parameters,
      )
      function_call_part = Part(function_call=function_call)
      llm_response.content.parts = [function_call_part]
    except (json.JSONDecodeError, ValidationError) as e:
      # Expected for plain-text answers: leave the response untouched.
      logger.debug(
          'Error attempting to parse JSON into function call. Leaving as text'
          ' response. %s',
          e,
      )
    except Exception as e:
      logger.warning('Error processing Gemma function call response: %s', e)

  @override
  async def _preprocess_request(self, llm_request: LlmRequest) -> None:
    self._move_function_calls_into_system_instruction(llm_request=llm_request)

    if system_instruction := llm_request.config.system_instruction:
      contents = llm_request.contents
      instruction_content = Content(
          role='user', parts=[Part.from_text(text=system_instruction)]
      )

      # NOTE: if history is preserved, we must include the system instructions
      # ONLY once at the beginning of any chain of contents. Also handles the
      # empty-history case, where the instruction would otherwise be dropped.
      if not contents or contents[0] != instruction_content:
        # only prepend if it hasn't already been done
        llm_request.contents = [instruction_content] + contents

      llm_request.config.system_instruction = None

    return await super()._preprocess_request(llm_request)

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends a request to the Gemma model.

    Args:
      llm_request: LlmRequest, the request to send to the Gemma model.
      stream: bool = False, whether to do streaming call.

    Yields:
      LlmResponse: The model response.

    Raises:
      ValueError: If the requested model is not a Gemma model.
    """
    # Validate explicitly rather than with `assert`, which is stripped
    # when Python runs with -O.
    if llm_request.model is None or not llm_request.model.startswith('gemma-'):
      raise ValueError(
          f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma'
          ' LLM is not supported.'
      )

    async for response in super().generate_content_async(llm_request, stream):
      self._extract_function_calls_from_response(response)
      yield response
233+
234+
235+
def _convert_content_parts_for_gemma(
    content_item: Content,
) -> tuple[list[Part], bool, bool]:
  """Converts function call/response parts within a content item to text parts.

  Gemma has no native function-calling API, so tool traffic must travel as
  plain text: function responses are rendered as human-readable text, and
  function calls are serialized to JSON text.

  Args:
    content_item: The original Content item.

  Returns:
    A tuple containing:
      - A list of new Part objects with function calls/responses converted to
        text.
      - A boolean indicating if any function response parts were found.
      - A boolean indicating if any function call parts were found.
  """
  new_parts: list[Part] = []
  has_function_response_part = False
  has_function_call_part = False

  # Content.parts is Optional on genai types; treat None as "no parts"
  # instead of raising TypeError on iteration.
  for part in content_item.parts or []:
    if func_response := part.function_response:
      has_function_response_part = True
      response_text = (
          f'Invoking tool `{func_response.name}` produced:'
          f' `{json.dumps(func_response.response)}`.'
      )
      new_parts.append(Part.from_text(text=response_text))
    elif func_call := part.function_call:
      has_function_call_part = True
      new_parts.append(
          Part.from_text(text=func_call.model_dump_json(exclude_none=True))
      )
    else:
      # Plain text (or other) parts pass through unchanged.
      new_parts.append(part)
  return new_parts, has_function_response_part, has_function_call_part
269+
270+
271+
def _build_gemma_function_system_instruction(
    function_declarations: list[FunctionDeclaration],
) -> str:
  """Constructs the system instruction string for Gemma function calling.

  Serializes every declaration to JSON and embeds the list in a prompt that
  also pins the exact reply format the model must use for a call.

  Args:
    function_declarations: The tool functions to advertise to the model.

  Returns:
    The instruction text, or an empty string when there are no declarations.
  """
  if not function_declarations:
    return ''

  declarations_json = ',\n'.join(
      declaration.model_dump_json(exclude_none=True)
      for declaration in function_declarations
  )
  return (
      'You have access to the following functions:\n['
      f'{declarations_json}\n]\n'
      'When you call a function, you MUST respond in the format of: '
      """{"name": function name, "parameters": dictionary of argument name and its value}\n"""
      'When you call a function, you MUST NOT include any other text in the'
      ' response.\n'
  )
295+
296+
297+
def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]:
298+
"""Attempts to find and return the last valid JSON object in a string.
299+
300+
This function is designed to extract JSON that might be embedded in a larger
301+
text, potentially with introductory or concluding remarks. It will always chose
302+
the last block of valid json found within the supplied text (if it exists).
303+
304+
Args:
305+
text: The input string to search for JSON objects.
306+
307+
Returns:
308+
A tuple:
309+
- bool: True if a valid JSON substring was found, False otherwise.
310+
- str | None: The last valid JSON substring found, or None if none was
311+
found.
312+
"""
313+
decoder = json.JSONDecoder()
314+
last_json_str = None
315+
start_pos = 0
316+
while start_pos < len(text):
317+
try:
318+
first_brace_index = text.index('{', start_pos)
319+
_, end_index = decoder.raw_decode(text[first_brace_index:])
320+
last_json_str = text[first_brace_index : first_brace_index + end_index]
321+
start_pos = first_brace_index + end_index
322+
except json.JSONDecodeError:
323+
start_pos = first_brace_index + 1
324+
except ValueError:
325+
break
326+
327+
if last_json_str:
328+
return True, last_json_str
329+
return False, None

0 commit comments

Comments
 (0)