diff --git a/src/azure-cli-core/azure/cli/core/__init__.py b/src/azure-cli-core/azure/cli/core/__init__.py
index 3251be00ab4..2b4ae31fca3 100644
--- a/src/azure-cli-core/azure/cli/core/__init__.py
+++ b/src/azure-cli-core/azure/cli/core/__init__.py
@@ -9,6 +9,8 @@
 import os
 import sys
 import timeit
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
 
 from knack.cli import CLI
 from knack.commands import CLICommandsLoader
@@ -32,6 +34,10 @@
 ALWAYS_LOADED_MODULES = []
 # Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core.
 ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next']
+# Timeout (in seconds) for loading a single module. Acts as a safety valve to prevent indefinite hangs.
+MODULE_LOAD_TIMEOUT_SECONDS = 30
+# Maximum number of worker threads for parallel module loading.
+MAX_WORKER_THREAD_COUNT = 4
 
 
 def _configure_knack():
@@ -195,6 +201,20 @@ def _configure_style(self):
         format_styled_text.theme = theme
 
 
+class ModuleLoadResult:  # pylint: disable=too-few-public-methods
+    """Outcome of loading one command module: its tables and loader, or the error that occurred."""
+
+    def __init__(self, module_name, command_table, group_table, elapsed_time,
+                 error=None, traceback_str=None, command_loader=None):
+        self.module_name = module_name
+        self.command_table = command_table
+        self.group_table = group_table
+        self.elapsed_time = elapsed_time
+        self.error = error
+        self.traceback_str = traceback_str
+        self.command_loader = command_loader
+
+
 class MainCommandsLoader(CLICommandsLoader):
 
     # Format string for pretty-print the command module table
@@ -221,11 +241,11 @@ def load_command_table(self, args):
         import pkgutil
         import traceback
         from azure.cli.core.commands import (
-            _load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource)
+            _load_extension_command_loader, ExtensionCommandSource)
         from azure.cli.core.extension import (
             get_extensions, get_extension_path, get_extension_modname)
         from azure.cli.core.breaking_change import (
-            import_core_breaking_changes, import_module_breaking_changes, import_extension_breaking_changes)
+            import_core_breaking_changes, import_extension_breaking_changes)
 
         def _update_command_table_from_modules(args, command_modules=None):
             """Loads command tables from modules and merge into the main command table.
@@ -253,38 +273,10 @@ def _update_command_table_from_modules(args, command_modules=None):
                 except ImportError as e:
                     logger.warning(e)
 
-            count = 0
-            cumulative_elapsed_time = 0
-            cumulative_group_count = 0
-            cumulative_command_count = 0
-            logger.debug("Loading command modules:")
-            logger.debug(self.header_mod)
+            results = self._load_modules(args, command_modules)
 
-            for mod in [m for m in command_modules if m not in BLOCKED_MODS]:
-                try:
-                    start_time = timeit.default_timer()
-                    module_command_table, module_group_table = _load_module_command_loader(self, args, mod)
-                    import_module_breaking_changes(mod)
-                    for cmd in module_command_table.values():
-                        cmd.command_source = mod
-                    self.command_table.update(module_command_table)
-                    self.command_group_table.update(module_group_table)
-
-                    elapsed_time = timeit.default_timer() - start_time
-                    logger.debug(self.item_format_string, mod, elapsed_time,
-                                 len(module_group_table), len(module_command_table))
-                    count += 1
-                    cumulative_elapsed_time += elapsed_time
-                    cumulative_group_count += len(module_group_table)
-                    cumulative_command_count += len(module_command_table)
-                except Exception as ex:  # pylint: disable=broad-except
-                    # Changing this error message requires updating CI script that checks for failed
-                    # module loading.
-                    from azure.cli.core import telemetry
-                    logger.error("Error loading command module '%s': %s", mod, ex)
-                    telemetry.set_exception(exception=ex, fault_type='module-load-error-' + mod,
-                                            summary='Error loading module: {}'.format(mod))
-                    logger.debug(traceback.format_exc())
+            count, cumulative_elapsed_time, cumulative_group_count, cumulative_command_count = \
+                self._process_results_with_timing(results)
             # Summary line
             logger.debug(self.item_format_string,
                          "Total ({})".format(count), cumulative_elapsed_time,
@@ -358,7 +350,13 @@ def _filter_modname(extensions):
                         # from an extension requires this map to be up-to-date.
                         # self._mod_to_ext_map[ext_mod] = ext_name
                         start_time = timeit.default_timer()
-                        extension_command_table, extension_group_table = \
+                        extension_command_table, extension_group_table, extension_command_loader = \
                             _load_extension_command_loader(self, args, ext_mod)
                         import_extension_breaking_changes(ext_mod)
+                        # _load_command_loader no longer registers the loader on `self`, so do it
+                        # here; otherwise extension commands are missing from cmd_to_loader_map.
+                        if extension_command_loader:
+                            self.loaders.append(extension_command_loader)
+                            for cmd in extension_command_table:
+                                self.cmd_to_loader_map[cmd] = [extension_command_loader]
 
@@ -561,6 +559,120 @@ def load_arguments(self, command=None):
                 self.extra_argument_registry.update(loader.extra_argument_registry)
                 loader._update_command_definitions()  # pylint: disable=protected-access
 
+    def _load_modules(self, args, command_modules):
+        """Load command modules on a thread pool and return one ModuleLoadResult per module.
+
+        Modules listed in BLOCKED_MODS are skipped. Results are returned in the order of
+        ``command_modules`` rather than completion order, so the command tables are merged in a
+        deterministic order. A module whose result is not available within
+        MODULE_LOAD_TIMEOUT_SECONDS of being waited on is reported as a failed load.
+
+        :param args: List of the arguments from the command line
+        :param command_modules: Names of the command modules to load
+        :return: List of ModuleLoadResult
+        """
+        from azure.cli.core.commands import BLOCKED_MODS
+
+        modules = [mod for mod in command_modules if mod not in BLOCKED_MODS]
+        results = []
+        # Not used as a context manager: leaving a `with` block calls shutdown(wait=True), which would
+        # block on a hung worker and defeat the timeout.
+        executor = ThreadPoolExecutor(max_workers=MAX_WORKER_THREAD_COUNT)
+        try:
+            futures = [executor.submit(self._load_single_module, mod, args) for mod in modules]
+            for mod, future in zip(modules, futures):
+                try:
+                    # Futures yielded by as_completed() are already finished, so result(timeout=...)
+                    # on them never times out. Waiting on each submitted future makes it effective.
+                    results.append(future.result(timeout=MODULE_LOAD_TIMEOUT_SECONDS))
+                except concurrent.futures.TimeoutError:
+                    future.cancel()  # Only has an effect if the load has not started yet
+                    logger.warning("Module '%s' load timeout after %s seconds", mod, MODULE_LOAD_TIMEOUT_SECONDS)
+                    results.append(ModuleLoadResult(mod, {}, {}, 0,
+                                                    Exception(f"Module '{mod}' load timeout")))
+                except (ImportError, AttributeError, TypeError, ValueError) as ex:
+                    logger.warning("Module '%s' load failed: %s", mod, ex)
+                    results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
+                except Exception as ex:  # pylint: disable=broad-exception-caught
+                    logger.warning("Module '%s' load failed with unexpected exception: %s", mod, ex)
+                    results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
+        finally:
+            executor.shutdown(wait=False)
+
+        return results
+
+    def _load_single_module(self, mod, args):
+        """Load one command module and return the outcome as a ModuleLoadResult.
+
+        Called on a worker thread by _load_modules. Any exception is captured in the result, together
+        with its traceback, instead of being raised.
+        """
+        from azure.cli.core.breaking_change import import_module_breaking_changes
+        from azure.cli.core.commands import _load_module_command_loader
+        import traceback
+        try:
+            start_time = timeit.default_timer()
+            module_command_table, module_group_table, command_loader = _load_module_command_loader(self, args, mod)
+            import_module_breaking_changes(mod)
+            elapsed_time = timeit.default_timer() - start_time
+            return ModuleLoadResult(mod, module_command_table, module_group_table, elapsed_time,
+                                    command_loader=command_loader)
+        except Exception as ex:  # pylint: disable=broad-except
+            tb_str = traceback.format_exc()
+            return ModuleLoadResult(mod, {}, {}, 0, ex, tb_str)
+
+    def _handle_module_load_error(self, result):
+        """Handle errors that occurred during module loading."""
+        from azure.cli.core import telemetry
+
+        # Changing this error message requires updating CI script that checks for failed
+        # module loading.
+        logger.error("Error loading command module '%s': %s", result.module_name, result.error)
+        telemetry.set_exception(exception=result.error,
+                                fault_type='module-load-error-' + result.module_name,
+                                summary='Error loading module: {}'.format(result.module_name))
+        if result.traceback_str:
+            logger.debug(result.traceback_str)
+
+    def _process_successful_load(self, result):
+        """Process successfully loaded module results."""
+        if result.command_loader:
+            self.loaders.append(result.command_loader)
+
+            for cmd in result.command_table:
+                self.cmd_to_loader_map[cmd] = [result.command_loader]
+
+        for cmd in result.command_table.values():
+            cmd.command_source = result.module_name
+
+        self.command_table.update(result.command_table)
+        self.command_group_table.update(result.group_table)
+
+        logger.debug(self.item_format_string, result.module_name, result.elapsed_time,
+                     len(result.group_table), len(result.command_table))
+
+    def _process_results_with_timing(self, results):
+        """Process pre-loaded module results with timing and progress reporting."""
+        logger.debug("Loading command modules:")
+        logger.debug(self.header_mod)
+
+        count = 0
+        cumulative_elapsed_time = 0
+        cumulative_group_count = 0
+        cumulative_command_count = 0
+
+        for result in results:
+            if result.error:
+                self._handle_module_load_error(result)
+            else:
+                self._process_successful_load(result)
+                count += 1
+                cumulative_elapsed_time += result.elapsed_time
+                cumulative_group_count += len(result.group_table)
+                cumulative_command_count += len(result.command_table)
+
+        return count, cumulative_elapsed_time, cumulative_group_count, cumulative_command_count
+
 
 class CommandIndex:
 
diff --git a/src/azure-cli-core/azure/cli/core/commands/__init__.py b/src/azure-cli-core/azure/cli/core/commands/__init__.py
index 696b6093f5d..1764fbdf25e 100644
--- a/src/azure-cli-core/azure/cli/core/commands/__init__.py
+++ b/src/azure-cli-core/azure/cli/core/commands/__init__.py
@@ -1134,22 +1134,17 @@ def _load_command_loader(loader, args, name, prefix):
             logger.debug("Module '%s' is missing `get_command_loader` entry.", name)
 
     command_table = {}
+    command_loader = None
 
     if loader_cls:
         command_loader = loader_cls(cli_ctx=loader.cli_ctx)
-        loader.loaders.append(command_loader)  # This will be used by interactive
         if command_loader.supported_resource_type():
             command_table = command_loader.load_command_table(args)
-            if command_table:
-                for cmd in list(command_table.keys()):
-                    # TODO: If desired to for extension to patch module, this can be uncommented
-                    # if loader.cmd_to_loader_map.get(cmd):
-                    #    loader.cmd_to_loader_map[cmd].append(command_loader)
-                    # else:
-                    loader.cmd_to_loader_map[cmd] = [command_loader]
     else:
         logger.debug("Module '%s' is missing `COMMAND_LOADER_CLS` entry.", name)
-    return command_table, command_loader.command_group_table
+
+    group_table = command_loader.command_group_table if command_loader else {}
+    return command_table, group_table, command_loader
 
 
 def _load_extension_command_loader(loader, args, ext):
diff --git a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py
index 21d03d47b5e..41cf1ea1af4 100644
--- a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py
+++ b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py
@@ -230,7 +230,7 @@ def load_command_table(self, args):
             if command_table:
                 module_command_table.update(command_table)
                 loader.loaders.append(command_loader)  # this will be used later by the load_arguments method
-        return module_command_table, command_loader.command_group_table
+        return module_command_table, command_loader.command_group_table, command_loader
 
     expected_command_index = {'hello': ['azure.cli.command_modules.hello', 'azext_hello2', 'azext_hello1'],
                               'extra': ['azure.cli.command_modules.extra']}
diff --git a/src/azure-cli-core/azure/cli/core/tests/test_command_table_integrity.py b/src/azure-cli-core/azure/cli/core/tests/test_command_table_integrity.py
new file mode 100644
index 00000000000..28fa52dd355
--- /dev/null
+++ b/src/azure-cli-core/azure/cli/core/tests/test_command_table_integrity.py
@@ -0,0 +1,57 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+
+import unittest
+
+from azure.cli.core.mock import DummyCli
+from azure.cli.core import MainCommandsLoader
+
+
+class CommandTableIntegrityTest(unittest.TestCase):
+
+    def setUp(self):
+        self.cli_ctx = DummyCli()
+
+    def test_command_table_integrity(self):
+        """Test command table loading produces valid, complete results."""
+
+        # Load command table using current implementation
+        loader = MainCommandsLoader(self.cli_ctx)
+        loader.load_command_table([])
+
+        # Test invariants that should always hold:
+
+        # 1. No corruption/duplicates
+        command_names = list(loader.command_table.keys())
+        unique_command_names = set(command_names)
+        self.assertEqual(len(unique_command_names), len(command_names), "No duplicate commands")
+
+        # 2. Core functionality exists (high-level groups that should always exist)
+        core_groups = ['vm', 'network', 'resource', 'account', 'group']
+        existing_groups = {cmd.split()[0] for cmd in loader.command_table.keys() if ' ' in cmd}
+        missing_core = [group for group in core_groups if group not in existing_groups]
+        self.assertEqual(len(missing_core), 0, f"Missing core command groups: {missing_core}")
+
+        # 3. Structural integrity
+        commands_without_source = []
+        for cmd_name, cmd_obj in loader.command_table.items():
+            if not hasattr(cmd_obj, 'command_source') or not cmd_obj.command_source:
+                commands_without_source.append(cmd_name)
+
+        self.assertEqual(len(commands_without_source), 0,
+                         f"Commands missing source: {commands_without_source[:5]}...")
+
+        # 4. Basic sanity - we loaded SOMETHING
+        self.assertGreater(len(loader.command_table), 0, "Commands were loaded")
+        self.assertGreater(len(loader.command_group_table), 0, "Groups were loaded")
+
+        # 5. Verify core groups are properly represented
+        found_core_groups = sorted(existing_groups & set(core_groups))
+        self.assertGreaterEqual(len(found_core_groups), 3,
+                                f"At least 3 core command groups should be present, found: {found_core_groups}")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/azure-cli-core/azure/cli/core/tests/test_parser.py b/src/azure-cli-core/azure/cli/core/tests/test_parser.py
index bff9c8ddc62..d584aafe383 100644
--- a/src/azure-cli-core/azure/cli/core/tests/test_parser.py
+++ b/src/azure-cli-core/azure/cli/core/tests/test_parser.py
@@ -188,7 +188,7 @@ def load_command_table(self, args):
             if command_table:
                 module_command_table.update(command_table)
                 loader.loaders.append(command_loader)  # this will be used later by the load_arguments method
-        return module_command_table, command_loader.command_group_table
+        return module_command_table, command_loader.command_group_table, command_loader
 
     @mock.patch('importlib.import_module', _mock_import_lib)
     @mock.patch('pkgutil.iter_modules', _mock_iter_modules)