-
Notifications
You must be signed in to change notification settings - Fork 2.8k
fix: Force AUDIO modality for native-audio models in /run_live (#4206) #4232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a446492
9fc0b22
45ea2bc
4cd51b4
2b88478
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -550,6 +550,28 @@ def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: | |
| return agent_or_app.root_agent | ||
| return agent_or_app | ||
|
|
||
| def _get_effective_modalities( | ||
| self, root_agent: BaseAgent, requested_modalities: List[str] | ||
| ) -> List[str]: | ||
| """Determines effective modalities, forcing AUDIO for native-audio models. | ||
|
|
||
| Native-audio models only support AUDIO modality. This method detects | ||
| native-audio models by checking if the model name contains "native-audio" | ||
| and forces AUDIO modality for those models. | ||
|
|
||
| Args: | ||
| root_agent: The root agent of the application. | ||
| requested_modalities: The modalities requested by the client. | ||
|
|
||
| Returns: | ||
| The effective modalities to use. | ||
| """ | ||
| model = getattr(root_agent, "model", None) | ||
| model_name = model if isinstance(model, str) else "" | ||
| if "native-audio" in model_name: | ||
|
Contributor
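There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The string `"native-audio"` is hard-coded in this check; consider defining it as a named module-level constant. For example: `_NATIVE_AUDIO_MODEL_TAG = "native-audio"`. This would allow you to reference `_NATIVE_AUDIO_MODEL_TAG` instead of repeating the literal. |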
||
| return ["AUDIO"] | ||
| return requested_modalities | ||
|
|
||
| def _create_runner(self, agentic_app: App) -> Runner: | ||
| """Create a runner with common services.""" | ||
| return Runner( | ||
|
|
@@ -1652,7 +1674,10 @@ async def run_agent_live( | |
|
|
||
| async def forward_events(): | ||
| runner = await self.get_runner_async(app_name) | ||
| run_config = RunConfig(response_modalities=modalities) | ||
| effective_modalities = self._get_effective_modalities( | ||
| runner.app.root_agent, modalities | ||
| ) | ||
| run_config = RunConfig(response_modalities=effective_modalities) | ||
| async with Aclosing( | ||
| runner.run_live( | ||
| session=session, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
|
|
||
| from fastapi.testclient import TestClient | ||
| from google.adk.agents.base_agent import BaseAgent | ||
| from google.adk.agents.llm_agent import LlmAgent | ||
| from google.adk.agents.run_config import RunConfig | ||
| from google.adk.apps.app import App | ||
| from google.adk.artifacts.base_artifact_service import ArtifactVersion | ||
|
|
@@ -1411,5 +1412,91 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): | |
| assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() | ||
|
|
||
|
|
||
| def test_native_audio_model_forces_audio_modality(): | ||
| """Test that native-audio models force AUDIO modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| native_audio_agent = LlmAgent( | ||
| name="native_audio_agent", | ||
| model="gemini-live-2.5-flash-native-audio", | ||
| ) | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
Comment on lines
+1424
to
+1433
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The instantiation of `AdkWebServer` with mocked services is repeated in each of these tests; consider moving it into a shared pytest fixture. Here's an example of what that fixture could look like:

@pytest.fixture
def adk_web_server_for_modality_tests():
    """Provides an AdkWebServer instance with mocked services for modality tests."""
    from google.adk.cli.adk_web_server import AdkWebServer
    return AdkWebServer(
        agent_loader=MagicMock(),
        session_service=MagicMock(),
        memory_service=MagicMock(),
        artifact_service=MagicMock(),
        credential_service=MagicMock(),
        eval_sets_manager=MagicMock(),
        eval_set_results_manager=MagicMock(),
        agents_dir=".",
    )

Each test could then accept `adk_web_server_for_modality_tests` as an argument. |
||
|
|
||
| # Test: requesting TEXT should be forced to AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| native_audio_agent, ["TEXT"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
| # Test: requesting AUDIO should stay AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| native_audio_agent, ["AUDIO"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
|
|
||
| def test_non_native_audio_model_keeps_requested_modality(): | ||
| """Test that non-native-audio models keep the requested modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| regular_agent = LlmAgent( | ||
| name="regular_agent", | ||
| model="gemini-2.5-flash", | ||
| ) | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
|
||
| # Test: requesting TEXT should stay TEXT | ||
| modalities = adk_web_server._get_effective_modalities(regular_agent, ["TEXT"]) | ||
| assert modalities == ["TEXT"] | ||
|
|
||
| # Test: requesting AUDIO should stay AUDIO | ||
| modalities = adk_web_server._get_effective_modalities( | ||
| regular_agent, ["AUDIO"] | ||
| ) | ||
| assert modalities == ["AUDIO"] | ||
|
|
||
|
|
||
| def test_agent_without_model_attribute(): | ||
| """Test that agents without model attribute keep requested modality.""" | ||
| from google.adk.cli.adk_web_server import AdkWebServer | ||
|
|
||
| base_agent = DummyAgent(name="base_agent") | ||
|
|
||
| adk_web_server = AdkWebServer( | ||
| agent_loader=MagicMock(), | ||
| session_service=MagicMock(), | ||
| memory_service=MagicMock(), | ||
| artifact_service=MagicMock(), | ||
| credential_service=MagicMock(), | ||
| eval_sets_manager=MagicMock(), | ||
| eval_set_results_manager=MagicMock(), | ||
| agents_dir=".", | ||
| ) | ||
|
|
||
| # Test: BaseAgent without model attr should keep requested modality | ||
| modalities = adk_web_server._get_effective_modalities(base_agent, ["TEXT"]) | ||
| assert modalities == ["TEXT"] | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main(["-xvs", __file__]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current logic for extracting the model name only handles the case where the `model` attribute is a string. The `LlmAgent.model` attribute can also be a `BaseLlm` object, in which case `isinstance(model, str)` would be false, `model_name` would become an empty string, and the check for "native-audio" would fail.

To make this more robust, you should also handle the case where `model` is an object (like `BaseLlm`) that has a `model` string attribute. It would also be beneficial to add a test case for an `LlmAgent` initialized with a `BaseLlm` object to ensure full coverage.