diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 2787c8d9d4..03fde10ed8 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -14,10 +14,14 @@ from __future__ import annotations from datetime import datetime +import importlib.util import json +import logging import os import shutil import subprocess +import sys +import traceback from typing import Final from typing import Optional import warnings @@ -25,6 +29,8 @@ import click from packaging.version import parse +logger = logging.getLogger('google_adk.' + __name__) + _IS_WINDOWS = os.name == 'nt' _GCLOUD_CMD = 'gcloud.cmd' if _IS_WINDOWS else 'gcloud' _LOCAL_STORAGE_FLAG_MIN_VERSION: Final[str] = '1.21.0' @@ -99,15 +105,33 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: """ _AGENT_ENGINE_APP_TEMPLATE: Final[str] = """ +import logging import os +import sys +import cloudpickle import vertexai from vertexai.agent_engines import AdkApp +_logger = logging.getLogger("google_adk." + __name__) + if {is_config_agent}: from google.adk.agents import config_agent_utils root_agent = config_agent_utils.from_config("{agent_folder}/root_agent.yaml") else: from .agent import {adk_app_object} + # Register the agent module for pickle-by-value serialization. + # This ensures custom BaseLlm implementations are serialized with their + # full class definition instead of just the import path, which fixes + # the "query method not found" error when using custom LLM clients. + from . import agent as _agent_module + cloudpickle.register_pickle_by_value(_agent_module) + # Also register any submodules that contain custom classes + for name, module in list(sys.modules.items()): + if module is not None and name.startswith(_agent_module.__name__.rsplit('.', 1)[0] + '.'): + try: + cloudpickle.register_pickle_by_value(module) + except Exception as e: + _logger.debug("Failed to register module %s for pickle-by-value: %s", name, e) if {express_mode}: # Whether or not to use Express Mode vertexai.init(api_key=os.environ.get("GOOGLE_API_KEY")) @@ -464,6 +488,218 @@ def _validate_gcloud_extra_args( ) +def _validate_agent_object(exported_obj: object, adk_app_object: str) -> None: + """Validates that the exported agent/app object is properly configured. + + This function performs deeper validation beyond just checking that the + object exists. It verifies that agents with custom BaseLlm implementations + have properly configured models that will work at Agent Engine runtime. + + Args: + exported_obj: The exported root_agent or app object. + adk_app_object: The name of the exported object ('root_agent' or 'app'). + + Raises: + click.ClickException: If the agent object is not properly configured. + """ + # Import here to avoid circular imports + try: + from google.adk.agents import BaseAgent + from google.adk.models import BaseLlm + except ImportError as e: + # If we can't import these, skip validation. This can happen in partial + # ADK environments or when optional dependencies are not installed. + logger.debug( + 'Skipping agent object validation: could not import ADK classes: %s', e + ) + return + + # For 'app' exports (AdkApp instances), we don't validate the internal agent + if adk_app_object == 'app': + return + + # For 'root_agent' exports, validate it's an agent with a valid model + if not isinstance(exported_obj, BaseAgent): + click.secho( + f'Warning: {adk_app_object} is not a BaseAgent instance. ' + 'Skipping model validation.', + fg='yellow', + ) + return + + # Check if the agent has a model attribute + model = getattr(exported_obj, 'model', None) + if model is None: + # Some agents might not have a model (e.g., workflow agents) + return + + # If the model is a string, it will be resolved by LLMRegistry at runtime + if isinstance(model, str): + return + + # If the model is a BaseLlm instance, validate it + if isinstance(model, BaseLlm): + model_class = type(model) + model_module = model_class.__module__ + + # Check if this is a custom BaseLlm (not from google.adk.models) + if not model_module.startswith('google.adk.models'): + click.echo( + f'Detected custom BaseLlm implementation: {model_class.__name__} ' + f'from {model_module}' + ) + + # Validate that the custom model can be pickled (required for Agent Engine) + try: + import cloudpickle + + cloudpickle.dumps(model) + click.echo( + f'Custom model {model_class.__name__} passed serialization check' + ) + except Exception as e: + raise click.ClickException( + f'Custom BaseLlm implementation {model_class.__name__} cannot be ' + f'serialized:\n{e}\n\n' + 'Agent Engine requires all custom LLM implementations to be ' + 'serializable. Please ensure:\n' + '1. Your custom BaseLlm class does not have non-serializable ' + 'attributes (file handles, connections, etc.)\n' + "2. All fields are JSON-serializable or use Pydantic's " + 'ConfigDict(arbitrary_types_allowed=True)\n' + '3. Consider implementing __getstate__ and __setstate__ methods ' + 'for custom serialization logic' + ) from e + + # Check if the model class is defined in a file within the agent folder + # by verifying the module can be imported with relative imports + if '.' not in model_module or model_module.count('.') < 2: + click.secho( + f'Warning: Custom model {model_class.__name__} is defined in ' + f'{model_module}. For Agent Engine deployment, ensure this module ' + 'is within your agent folder and uses relative imports.', + fg='yellow', + ) + + +def _validate_agent_import( + agent_src_path: str, + adk_app_object: str, + is_config_agent: bool, +) -> None: + """Validates that the agent module can be imported successfully. + + This pre-deployment validation catches common issues like missing + dependencies or import errors in custom BaseLlm implementations before + the agent is deployed to Agent Engine. This provides clearer error + messages and prevents deployments that would fail at runtime. + + Args: + agent_src_path: Path to the staged agent source code. + adk_app_object: The Python object name to import ('root_agent' or 'app'). + is_config_agent: Whether this is a config-based agent. + + Raises: + click.ClickException: If the agent module cannot be imported. + """ + if is_config_agent: + # Config agents are loaded from YAML, skip Python import validation + return + + agent_module_path = os.path.join(agent_src_path, 'agent.py') + if not os.path.exists(agent_module_path): + raise click.ClickException( + f'Agent module not found at {agent_module_path}. ' + 'Please ensure your agent folder contains an agent.py file.' + ) + + # Add the parent directory to sys.path temporarily for import resolution + parent_dir = os.path.dirname(agent_src_path) + module_name = os.path.basename(agent_src_path) + + original_sys_path = sys.path.copy() + try: + # Add parent directory to path so imports work correctly + if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + + # Load the agent module spec + spec = importlib.util.spec_from_file_location( + f'{module_name}.agent', + agent_module_path, + submodule_search_locations=[agent_src_path], + ) + if spec is None or spec.loader is None: + raise click.ClickException( + f'Failed to load module spec from {agent_module_path}' + ) + + # Try to load the module + module = importlib.util.module_from_spec(spec) + sys.modules[f'{module_name}.agent'] = module + + try: + spec.loader.exec_module(module) + except ImportError as e: + error_msg = str(e) + tb = traceback.format_exc() + + # Check for common issues + if 'BaseLlm' in tb or 'base_llm' in tb.lower(): + raise click.ClickException( + 'Failed to import agent module due to a BaseLlm-related error:\n' + f'{error_msg}\n\n' + 'This error often occurs when deploying agents with custom LLM ' + 'implementations. Please ensure:\n' + '1. All custom LLM classes are defined in files within your agent ' + 'folder\n' + '2. All required dependencies are listed in requirements.txt\n' + '3. Import paths use relative imports (e.g., "from .my_llm import ' + 'MyLlm")\n' + '4. Your custom BaseLlm implementation is serializable' + ) from e + else: + raise click.ClickException( + f'Failed to import agent module:\n{error_msg}\n\n' + 'Please ensure all dependencies are listed in requirements.txt ' + 'and all imports are resolvable.\n\n' + f'Full traceback:\n{tb}' + ) from e + except Exception as e: + tb = traceback.format_exc() + raise click.ClickException( + f'Error while loading agent module:\n{e}\n\n' + 'Please check your agent code for errors.\n\n' + f'Full traceback:\n{tb}' + ) from e + + # Check that the expected object exists + if not hasattr(module, adk_app_object): + available_attrs = [ + attr for attr in dir(module) if not attr.startswith('_') + ] + raise click.ClickException( + f"Agent module does not export '{adk_app_object}'. " + f'Available exports: {available_attrs}\n\n' + 'Please ensure your agent.py exports either "root_agent" or "app".' + ) + + # Validate that the exported object is properly configured for Agent Engine + exported_obj = getattr(module, adk_app_object) + _validate_agent_object(exported_obj, adk_app_object) + + click.echo( + 'Agent module validation successful: ' + f'found "{adk_app_object}" in agent.py' + ) + + finally: + # Restore original sys.path + sys.path[:] = original_sys_path + # Clean up the module from sys.modules + sys.modules.pop(f'{module_name}.agent', None) + + def _get_service_option_by_adk_version( adk_version: str, session_uri: Optional[str], @@ -952,6 +1188,10 @@ def to_agent_engine( click.echo(f'Config agent detected: {config_root_agent_file}') is_config_agent = True + # Validate that the agent module can be imported before deployment + click.echo('Validating agent module...') + _validate_agent_import(agent_src_path, adk_app_object, is_config_agent) + adk_app_file = os.path.join(temp_folder, f'{adk_app}.py') if adk_app_object == 'root_agent': adk_app_type = 'agent' diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index 43b9e07a2e..12dcac97d3 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -14,7 +14,6 @@ """Tests for utilities in cli_deploy.""" - from __future__ import annotations import importlib @@ -83,7 +82,9 @@ def agent_dir(tmp_path: Path) -> Callable[[bool, bool], Path]: def _factory(include_requirements: bool, include_env: bool) -> Path: base = tmp_path / "agent" base.mkdir() - (base / "agent.py").write_text("# dummy agent") + (base / "agent.py").write_text( + "# dummy agent\nroot_agent = 'dummy_agent'\n" + ) (base / "__init__.py").touch() if include_requirements: (base / "requirements.txt").write_text("pytest\n") @@ -285,6 +286,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: assert "adk_app = AdkApp(" in content assert "agent=root_agent" in content assert "enable_tracing=True" in content + # Verify cloudpickle pickle-by-value registration for custom BaseLlm support + assert "import cloudpickle" in content + assert "cloudpickle.register_pickle_by_value" in content reqs_path = tmp_dir / "requirements.txt" assert reqs_path.is_file() assert "google-cloud-aiplatform[adk,agent_engines]" in reqs_path.read_text() @@ -394,3 +398,351 @@ def mock_subprocess_run(*args, **kwargs): # 4. Verify cleanup assert str(rmtree_recorder.get_last_call_args()[0]) == str(tmp_path) + + +# _validate_agent_import tests +class TestValidateAgentImport: + """Tests for the _validate_agent_import function.""" + + def test_skips_config_agents(self, tmp_path: Path) -> None: + """Config agents should skip validation.""" + # This should not raise even with no agent.py file + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=True + ) + + def test_raises_on_missing_agent_module(self, tmp_path: Path) -> None: + """Should raise when agent.py is missing.""" + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + assert "Agent module not found" in str(exc_info.value) + + def test_raises_on_missing_export(self, tmp_path: Path) -> None: + """Should raise when the expected export is missing.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("some_other_var = 'hello'\n") + (tmp_path / "__init__.py").touch() + + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + assert "does not export 'root_agent'" in str(exc_info.value) + assert "some_other_var" in str(exc_info.value) + + def test_success_with_root_agent_export(self, tmp_path: Path) -> None: + """Should succeed when root_agent is exported.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("root_agent = 'my_agent'\n") + (tmp_path / "__init__.py").touch() + + # Should not raise + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + + def test_success_with_app_export(self, tmp_path: Path) -> None: + """Should succeed when app is exported.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("app = 'my_app'\n") + (tmp_path / "__init__.py").touch() + + # Should not raise + cli_deploy._validate_agent_import( + str(tmp_path), "app", is_config_agent=False + ) + + def test_raises_on_import_error(self, tmp_path: Path) -> None: + """Should raise with helpful message on ImportError.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("from nonexistent_module import something\n") + (tmp_path / "__init__.py").touch() + + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + assert "Failed to import agent module" in str(exc_info.value) + assert "nonexistent_module" in str(exc_info.value) + + def test_raises_on_basellm_import_error(self, tmp_path: Path) -> None: + """Should provide specific guidance for BaseLlm import errors.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text( + "from google.adk.models.base_llm import NonexistentBaseLlm\n" + ) + (tmp_path / "__init__.py").touch() + + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + assert "BaseLlm-related error" in str(exc_info.value) + assert "custom LLM" in str(exc_info.value) + + def test_raises_on_syntax_error(self, tmp_path: Path) -> None: + """Should raise on syntax errors in agent.py.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("def invalid syntax here:\n") + (tmp_path / "__init__.py").touch() + + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + assert "Error while loading agent module" in str(exc_info.value) + + def test_cleans_up_sys_modules(self, tmp_path: Path) -> None: + """Should clean up sys.modules after validation.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("root_agent = 'my_agent'\n") + (tmp_path / "__init__.py").touch() + + module_name = tmp_path.name + agent_module_key = f"{module_name}.agent" + + # Ensure module is not in sys.modules before + assert agent_module_key not in sys.modules + + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + + # Ensure module is cleaned up after + assert agent_module_key not in sys.modules + + def test_restores_sys_path(self, tmp_path: Path) -> None: + """Should restore sys.path after validation.""" + agent_file = tmp_path / "agent.py" + agent_file.write_text("root_agent = 'my_agent'\n") + (tmp_path / "__init__.py").touch() + + original_path = sys.path.copy() + + cli_deploy._validate_agent_import( + str(tmp_path), "root_agent", is_config_agent=False + ) + + assert sys.path == original_path + + +# _validate_agent_object tests +class TestValidateAgentObject: + """Tests for the _validate_agent_object function.""" + + def test_skips_app_export(self) -> None: + """Should skip validation for 'app' exports.""" + # Should not raise even with an invalid object + cli_deploy._validate_agent_object("not_an_agent", "app") + + def test_warns_on_non_baseagent(self) -> None: + """Should warn but not raise for non-BaseAgent objects.""" + # Should not raise + cli_deploy._validate_agent_object("just_a_string", "root_agent") + + def test_skips_string_models(self) -> None: + """Should skip validation when model is a string.""" + from google.adk.agents import Agent + + # Create agent with string model + agent = Agent(model="gemini-2.0-flash", name="test_agent") + # Should not raise + cli_deploy._validate_agent_object(agent, "root_agent") + + def test_validates_custom_basellm_serialization(self, tmp_path: Path) -> None: + """Should validate that custom BaseLlm can be serialized.""" + from typing import AsyncGenerator + + from google.adk.agents import Agent + from google.adk.models import BaseLlm + + # Create a simple serializable custom BaseLlm + class SerializableCustomLlm(BaseLlm): + model: str = "custom-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["custom-model"] + + async def generate_content_async(self, llm_request, stream=False): + yield None + + custom_llm = SerializableCustomLlm() + agent = Agent(model=custom_llm, name="test_agent") + + # Should not raise - the custom LLM is serializable + cli_deploy._validate_agent_object(agent, "root_agent") + + def test_raises_on_non_serializable_custom_basellm(self) -> None: + """Should raise when custom BaseLlm cannot be serialized.""" + from unittest.mock import MagicMock + from unittest.mock import patch + + from google.adk.agents import Agent + from google.adk.models import BaseLlm + + # Create a simple custom BaseLlm + class CustomLlm(BaseLlm): + model: str = "custom-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["custom-model"] + + async def generate_content_async(self, llm_request, stream=False): + yield None + + custom_llm = CustomLlm() + agent = Agent(model=custom_llm, name="test_agent") + + # Mock cloudpickle.dumps to raise an exception + with patch("cloudpickle.dumps") as mock_dumps: + mock_dumps.side_effect = Exception("Cannot pickle this object") + + with pytest.raises(click.ClickException) as exc_info: + cli_deploy._validate_agent_object(agent, "root_agent") + assert "cannot be serialized" in str(exc_info.value) + assert "CustomLlm" in str(exc_info.value) + + def test_skips_builtin_models(self) -> None: + """Should skip serialization check for built-in ADK models.""" + from google.adk.agents import Agent + from google.adk.models import Gemini + + # Create agent with built-in Gemini model + gemini = Gemini(model="gemini-2.0-flash") + agent = Agent(model=gemini, name="test_agent") + + # Should not raise - built-in models don't need serialization check + cli_deploy._validate_agent_object(agent, "root_agent") + + +# _AGENT_ENGINE_APP_TEMPLATE tests +class TestAgentEngineAppTemplate: + """Tests for the Agent Engine app template generation.""" + + def test_template_includes_cloudpickle_imports(self) -> None: + """Template should include cloudpickle imports for serialization fix.""" + template = cli_deploy._AGENT_ENGINE_APP_TEMPLATE + assert "import cloudpickle" in template + assert "import sys" in template + + def test_template_registers_agent_module_for_pickle_by_value(self) -> None: + """Template should register agent module for pickle-by-value.""" + template = cli_deploy._AGENT_ENGINE_APP_TEMPLATE + assert "cloudpickle.register_pickle_by_value(_agent_module)" in template + + def test_template_registers_submodules_for_pickle_by_value(self) -> None: + """Template should register submodules (clients/, tools/) for pickle-by-value.""" + template = cli_deploy._AGENT_ENGINE_APP_TEMPLATE + # Verify it iterates over sys.modules to find submodules + assert "for name, module in list(sys.modules.items())" in template + # Verify it registers submodules that match the agent package + assert "cloudpickle.register_pickle_by_value(module)" in template + + def test_template_handles_non_registerable_modules(self) -> None: + """Template should handle modules that can't be registered gracefully.""" + template = cli_deploy._AGENT_ENGINE_APP_TEMPLATE + # Verify it catches exceptions from register_pickle_by_value + assert "except Exception:" in template + assert "pass" in template + + def test_template_skips_cloudpickle_for_config_agents(self) -> None: + """Config agents should not have cloudpickle registration.""" + template = cli_deploy._AGENT_ENGINE_APP_TEMPLATE + # The cloudpickle registration is inside the 'else' block for non-config agents + # Config agents use the 'if {is_config_agent}' branch which doesn't have it + lines = template.split("\n") + in_config_agent_block = False + config_agent_has_cloudpickle = False + + for line in lines: + if "if {is_config_agent}:" in line: + in_config_agent_block = True + elif in_config_agent_block and line.strip().startswith("else:"): + in_config_agent_block = False + elif ( + in_config_agent_block + and "cloudpickle.register_pickle_by_value" in line + ): + config_agent_has_cloudpickle = True + + assert ( + not config_agent_has_cloudpickle + ), "Config agents should not have cloudpickle registration" + + +# Cloudpickle serialization integration tests +class TestCloudpickleSerializationFix: + """Integration tests for the cloudpickle serialization fix.""" + + def test_custom_basellm_in_submodule_can_be_serialized( + self, tmp_path: Path + ) -> None: + """Custom BaseLlm defined in a submodule should be serializable.""" + import cloudpickle + from google.adk.models import BaseLlm + + # Create a custom BaseLlm class + class SubmoduleCustomLlm(BaseLlm): + model: str = "custom-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["custom-model"] + + async def generate_content_async(self, llm_request, stream=False): + yield None + + # Create an instance + custom_llm = SubmoduleCustomLlm() + + # Simulate the fix by registering the module + # (In real deployment, the template does this automatically) + import types + + fake_module = types.ModuleType("fake_agent_module") + fake_module.SubmoduleCustomLlm = SubmoduleCustomLlm + sys.modules["fake_agent_module"] = fake_module + + try: + cloudpickle.register_pickle_by_value(fake_module) + + # Serialize and deserialize + serialized = cloudpickle.dumps(custom_llm) + deserialized = cloudpickle.loads(serialized) + + # Verify the class is intact + assert type(deserialized).__name__ == "SubmoduleCustomLlm" + assert deserialized.model == "custom-model" + finally: + # Cleanup + sys.modules.pop("fake_agent_module", None) + + def test_agent_with_custom_basellm_can_be_serialized(self) -> None: + """Agent with custom BaseLlm should be serializable after fix.""" + import cloudpickle + from google.adk.agents import Agent + from google.adk.models import BaseLlm + + class SerializableCustomLlm(BaseLlm): + model: str = "test-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["test-model"] + + async def generate_content_async(self, llm_request, stream=False): + yield None + + custom_llm = SerializableCustomLlm() + agent = Agent(model=custom_llm, name="test_agent") + + # Should be able to serialize and deserialize + serialized = cloudpickle.dumps(agent) + deserialized = cloudpickle.loads(serialized) + + assert deserialized.name == "test_agent" + assert type(deserialized.model).__name__ == "SerializableCustomLlm"