diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index b97932d042..39be391e46 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -550,6 +550,28 @@ def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: return agent_or_app.root_agent return agent_or_app + def _get_effective_modalities( + self, root_agent: BaseAgent, requested_modalities: List[str] + ) -> List[str]: + """Determines effective modalities, forcing AUDIO for native-audio models. + + Native-audio models only support AUDIO modality. This method detects + native-audio models by checking if the model name contains "native-audio" + and forces AUDIO modality for those models. + + Args: + root_agent: The root agent of the application. + requested_modalities: The modalities requested by the client. + + Returns: + The effective modalities to use. + """ + model = getattr(root_agent, "model", None) + model_name = model if isinstance(model, str) else "" + if "native-audio" in model_name: + return ["AUDIO"] + return requested_modalities + def _create_runner(self, agentic_app: App) -> Runner: """Create a runner with common services.""" return Runner( @@ -1652,7 +1674,10 @@ async def run_agent_live( async def forward_events(): runner = await self.get_runner_async(app_name) - run_config = RunConfig(response_modalities=modalities) + effective_modalities = self._get_effective_modalities( + runner.app.root_agent, modalities + ) + run_config = RunConfig(response_modalities=effective_modalities) async with Aclosing( runner.run_live( session=session, diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 0c69605349..42b17f93ce 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -29,6 +29,7 @@ from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent from 
def _new_mocked_adk_web_server():
  """Builds an AdkWebServer whose collaborators are all MagicMock objects.

  The server class is imported lazily inside the helper, as the tests that
  use it only need it at call time.
  """
  from google.adk.cli.adk_web_server import AdkWebServer

  mocked_dependencies = {
      dependency: MagicMock()
      for dependency in (
          "agent_loader",
          "session_service",
          "memory_service",
          "artifact_service",
          "credential_service",
          "eval_sets_manager",
          "eval_set_results_manager",
      )
  }
  return AdkWebServer(agents_dir=".", **mocked_dependencies)


def test_native_audio_model_forces_audio_modality():
  """A native-audio model resolves to AUDIO whatever modality is requested."""
  agent = LlmAgent(
      name="native_audio_agent",
      model="gemini-live-2.5-flash-native-audio",
  )
  server = _new_mocked_adk_web_server()

  # A TEXT request is overridden to AUDIO.
  forced = server._get_effective_modalities(agent, ["TEXT"])
  assert forced == ["AUDIO"]

  # An AUDIO request is returned as AUDIO.
  kept = server._get_effective_modalities(agent, ["AUDIO"])
  assert kept == ["AUDIO"]


def test_non_native_audio_model_keeps_requested_modality():
  """A model without "native-audio" in its name keeps the requested modality."""
  agent = LlmAgent(
      name="regular_agent",
      model="gemini-2.5-flash",
  )
  server = _new_mocked_adk_web_server()

  # TEXT stays TEXT.
  text_result = server._get_effective_modalities(agent, ["TEXT"])
  assert text_result == ["TEXT"]

  # AUDIO stays AUDIO.
  audio_result = server._get_effective_modalities(agent, ["AUDIO"])
  assert audio_result == ["AUDIO"]


def test_agent_without_model_attribute():
  """An agent built from DummyAgent keeps the requested modality."""
  agent = DummyAgent(name="base_agent")
  server = _new_mocked_adk_web_server()

  # With no model name to inspect, the request passes through untouched.
  result = server._get_effective_modalities(agent, ["TEXT"])
  assert result == ["TEXT"]