diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index bccaa46cb..34fc2fe9e 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -22,11 +22,10 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v6 - with: - version: "0.5.30" - name: sync uv run: | + uv venv --seed uv sync diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 66dfd5eb4..6ddfe763a 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -25,12 +25,117 @@ from codeflash.models.models import FunctionSource +class GlobalFunctionCollector(cst.CSTVisitor): + """Collects all module-level function definitions (not inside classes or other functions).""" + + def __init__(self) -> None: + super().__init__() + self.functions: dict[str, cst.FunctionDef] = {} + self.function_order: list[str] = [] + self.scope_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + if self.scope_depth == 0: + # Module-level function + name = node.name.value + self.functions[name] = node + if name not in self.function_order: + self.function_order.append(name) + self.scope_depth += 1 + return True + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 + self.scope_depth -= 1 + + def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002 + self.scope_depth += 1 + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.scope_depth -= 1 + + +class GlobalFunctionTransformer(cst.CSTTransformer): + """Transforms/adds module-level functions from the new file to the original file.""" + + def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None: + super().__init__() + self.new_functions = new_functions + self.new_function_order = new_function_order + self.processed_functions: set[str] = set() + self.scope_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 + self.scope_depth += 1 + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + self.scope_depth -= 1 + if self.scope_depth > 0: + return updated_node + + # Check if this is a module-level function we need to replace + name = original_node.name.value + if name in self.new_functions: + self.processed_functions.add(name) + return self.new_functions[name] + return updated_node + + def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 + self.scope_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002 + self.scope_depth -= 1 + return updated_node + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + # Add any new functions that weren't in the original file + new_statements = list(updated_node.body) + + functions_to_append = [ + self.new_functions[name] + for name in self.new_function_order + if name not in self.processed_functions and name in self.new_functions + ] + + if functions_to_append: + # Find the position of the last function or class definition + insert_index = find_insertion_index_after_imports(updated_node) + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + insert_index = i + 1 + + # Add empty line before each new function + function_nodes = [] + for func in functions_to_append: + func_with_empty_line = 
func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines]) + function_nodes.append(func_with_empty_line) + + new_statements = list(chain(new_statements[:insert_index], function_nodes, new_statements[insert_index:])) + + return updated_node.with_changes(body=new_statements) + + +def collect_referenced_names(node: cst.CSTNode) -> set[str]: + """Collect all names referenced in a CST node using recursive traversal.""" + names: set[str] = set() + + def _collect(n: cst.CSTNode) -> None: + if isinstance(n, cst.Name): + names.add(n.value) + # Recursively process all children + for child in n.children: + _collect(child) + + _collect(node) + return names + + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" def __init__(self) -> None: super().__init__() - self.assignments: dict[str, cst.Assign] = {} + self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {} self.assignment_order: list[str] = [] # Track scope depth to identify global assignments self.scope_depth = 0 @@ -72,6 +177,21 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]: self.assignment_order.append(name) return True + def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]: + # Handle annotated assignments like: _CACHE: Dict[str, int] = {} + # Only process module-level annotated assignments with a value + if ( + self.scope_depth == 0 + and self.if_else_depth == 0 + and isinstance(node.target, cst.Name) + and node.value is not None + ): + name = node.target.value + self.assignments[name] = node + if name not in self.assignment_order: + self.assignment_order.append(name) + return True + def find_insertion_index_after_imports(node: cst.Module) -> int: """Find the position of the last import statement in the top-level of the module.""" @@ -103,7 +223,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int: class GlobalAssignmentTransformer(cst.CSTTransformer): """Transforms global assignments in the original file with those from the new file.""" - def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None: + def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None: super().__init__() self.new_assignments = new_assignments self.new_assignment_order = new_assignment_order @@ -150,38 +270,120 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c return updated_node + def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode: + if self.scope_depth > 0 or self.if_else_depth > 0: + return updated_node + + # Check if this is a global annotated assignment we need to replace + if isinstance(original_node.target, cst.Name): + name = original_node.target.value + if name in self.new_assignments: + self.processed_assignments.add(name) + return self.new_assignments[name] + + return updated_node + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) # Find assignments to append assignments_to_append = [ - self.new_assignments[name] + (name, self.new_assignments[name]) for name in self.new_assignment_order if name not in self.processed_assignments and name in self.new_assignments ] - if assignments_to_append: - # after last top-level imports - insert_index = find_insertion_index_after_imports(updated_node) + if not assignments_to_append: + 
return updated_node.with_changes(body=new_statements) + + # Collect all class and function names defined in the module + # These are the names that assignments might reference + module_defined_names: set[str] = set() + for stmt in new_statements: + if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): + module_defined_names.add(stmt.name.value) + + # Partition assignments: those that reference module definitions go at the end, + # those that don't can go right after imports + assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + + for name, assignment in assignments_to_append: + # Get the value being assigned + if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None: + value_node = assignment.value + else: + # No value to analyze, safe to place after imports + assignments_after_imports.append((name, assignment)) + continue + + # Collect names referenced in the assignment value + referenced_names = collect_referenced_names(value_node) + + # Check if any referenced names are module-level definitions + if referenced_names & module_defined_names: + # This assignment references a class/function, place it after definitions + assignments_after_definitions.append((name, assignment)) + else: + # Safe to place right after imports + assignments_after_imports.append((name, assignment)) + # Insert assignments that don't depend on module definitions right after imports + if assignments_after_imports: + insert_index = find_insertion_index_after_imports(updated_node) assignment_lines = [ cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for assignment in assignments_to_append + for _, assignment in assignments_after_imports ] + new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) + + # Insert assignments that depend on module definitions after all class/function definitions + if assignments_after_definitions: + # Find the position after the last function or class definition + insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements)) + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + insert_index = i + 1 + assignment_lines = [ + cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) + for _, assignment in assignments_after_definitions + ] new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) - # Add a blank line after the last assignment if needed - after_index = insert_index + len(assignment_lines) - if after_index < len(new_statements): - next_stmt = new_statements[after_index] - # If there's no empty line, add one - has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines) - if not has_empty: - new_statements[after_index] = next_stmt.with_changes( - leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines] - ) + return updated_node.with_changes(body=new_statements) + + +class GlobalStatementTransformer(cst.CSTTransformer): + """Transformer that appends global statements at the end of the module. + + This ensures that global statements (like function calls at module level) are placed + after all functions, classes, and assignments they might reference, preventing NameError + at module load time. 
+ + This transformer should be run LAST after GlobalFunctionTransformer and + GlobalAssignmentTransformer have already added their content. + """ + + def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None: + super().__init__() + self.global_statements = global_statements + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.global_statements: + return updated_node + + new_statements = list(updated_node.body) + + # Add empty line before each statement for readability + statement_lines = [ + stmt.with_changes(leading_lines=[cst.EmptyLine(), *stmt.leading_lines]) for stmt in self.global_statements + ] + + # Append statements at the end of the module + # This ensures they come after all functions, classes, and assignments + new_statements.extend(statement_lines) return updated_node.with_changes(body=new_statements) @@ -213,8 +415,8 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: AR def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: if not self.in_function_or_class: for statement in node.body: - # Skip imports - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): + # Skip imports and assignments (both regular and annotated) + if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)): self.global_statements.append(node) break @@ -309,40 +511,6 @@ def visit_Try(self, node: cst.Try) -> None: self._collect_imports_from_block(node.body) -class ImportInserter(cst.CSTTransformer): - """Transformer that inserts global statements after the last import.""" - - def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None: - super().__init__() - self.global_statements = global_statements - self.last_import_line = last_import_line - self.current_line = 0 - self.inserted = False - - def leave_SimpleStatementLine( - self, - original_node: cst.SimpleStatementLine, # noqa: ARG002 - updated_node: cst.SimpleStatementLine, - ) -> cst.Module: - self.current_line += 1 - - # If we're right after the last import and haven't inserted yet - if self.current_line == self.last_import_line and not self.inserted: - self.inserted = True - return cst.Module(body=[updated_node, *self.global_statements]) - - return cst.Module(body=[updated_node]) - - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 - # If there were no imports, add at the beginning of the module - if self.last_import_line == 0 and not self.inserted: - updated_body = list(updated_node.body) - for stmt in reversed(self.global_statements): - updated_body.insert(0, stmt) - return updated_node.with_changes(body=updated_body) - return updated_node - - def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) @@ -394,34 +562,58 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: continue unique_global_statements.append(stmt) - mod_dst_code = dst_module_code - # Insert unique global statements if any - if unique_global_statements: - last_import_line = find_last_import_line(dst_module_code) - # Reuse already-parsed dst_module - transformer = ImportInserter(unique_global_statements, last_import_line) - # Use visit inplace, don't parse again - modified_module = dst_module.visit(transformer) - mod_dst_code = 
modified_module.code - # Parse the code after insertion - original_module = cst.parse_module(mod_dst_code) - else: - # No new statements to insert, reuse already-parsed dst_module - original_module = dst_module + # Reuse already-parsed dst_module + original_module = dst_module # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file - new_collector = GlobalAssignmentCollector() - src_module.visit(new_collector) - # Only create transformer if there are assignments to insert/transform - if not new_collector.assignments: # nothing to transform - return mod_dst_code + new_assignment_collector = GlobalAssignmentCollector() + src_module.visit(new_assignment_collector) + + # Collect module-level functions from both source and destination + src_function_collector = GlobalFunctionCollector() + src_module.visit(src_function_collector) + + dst_function_collector = GlobalFunctionCollector() + original_module.visit(dst_function_collector) + + # Filter out functions that already exist in the destination (only add truly new functions) + new_functions = { + name: func + for name, func in src_function_collector.functions.items() + if name not in dst_function_collector.functions + } + new_function_order = [name for name in src_function_collector.function_order if name in new_functions] + + # If there are no assignments, no new functions, and no global statements, return unchanged + if not new_assignment_collector.assignments and not new_functions and not unique_global_statements: + return dst_module_code + + # The order of transformations matters: + # 1. Functions first - so assignments and statements can reference them + # 2. Assignments second - so they come after functions but before statements + # 3. Global statements last - so they can reference both functions and assignments + + # Transform functions if any + if new_functions: + function_transformer = GlobalFunctionTransformer(new_functions, new_function_order) + original_module = original_module.visit(function_transformer) - # Transform the original destination module - transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) - transformed_module = original_module.visit(transformer) + # Transform assignments if any + if new_assignment_collector.assignments: + transformer = GlobalAssignmentTransformer( + new_assignment_collector.assignments, new_assignment_collector.assignment_order + ) + original_module = original_module.visit(transformer) + + # Insert global statements (like function calls at module level) LAST, + # after all functions and assignments are added, to ensure they can reference any + # functions or variables defined in the module + if unique_global_statements: + statement_transformer = GlobalStatementTransformer(unique_global_statements) + original_module = original_module.visit(statement_transformer) - return transformed_module.code + return original_module.code def resolve_star_import(module_name: str, project_root: Path) -> set[str]: diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 164440f9b..4c38f37e5 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -5,6 +5,7 @@ import os from collections import defaultdict from itertools import chain +from pathlib import Path from typing import TYPE_CHECKING, cast import libcst as cst @@ -16,6 +17,7 @@ from codeflash.context.unused_definition_remover import ( 
collect_top_level_defs_with_usages, extract_names_from_targets, + get_section_names, remove_unused_definitions_by_function_names, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 @@ -29,14 +31,44 @@ from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: - from pathlib import Path - from jedi.api.classes import Name from libcst import CSTNode from codeflash.context.unused_definition_remover import UsageInfo +def build_testgen_context( + helpers_of_fto_dict: dict[Path, set[FunctionSource]], + helpers_of_helpers_dict: dict[Path, set[FunctionSource]], + project_root_path: Path, + remove_docstrings: bool, # noqa: FBT001 + include_imported_classes: bool, # noqa: FBT001 +) -> CodeStringsMarkdown: + """Build testgen context with optional imported class definitions and external base inits.""" + testgen_context = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=remove_docstrings, + code_context_type=CodeContextType.TESTGEN, + ) + + if include_imported_classes: + imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) + if imported_class_context.code_strings: + testgen_context = CodeStringsMarkdown( + code_strings=testgen_context.code_strings + imported_class_context.code_strings + ) + + external_base_inits = get_external_base_class_inits(testgen_context, project_root_path) + if external_base_inits.code_strings: + testgen_context = CodeStringsMarkdown( + code_strings=testgen_context.code_strings + external_base_inits.code_strings + ) + + return testgen_context + + def get_code_optimization_context( function_to_optimize: FunctionToOptimize, project_root_path: Path, @@ -120,55 +152,37 @@ def get_code_optimization_context( logger.debug("Code context has exceeded token limit, removing read-only code") read_only_context_code = "" - # Extract code context for testgen - testgen_context = extract_code_markdown_context_from_files( + # Extract code context for testgen with progressive fallback for token limits + # Try in order: full context -> remove docstrings -> remove imported classes + testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=False, - code_context_type=CodeContextType.TESTGEN, + include_imported_classes=True, ) - # Extract class definitions for imported types from project modules - # This helps the LLM understand class constructors and structure - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - # Merge imported class definitions into testgen context - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: - # First try removing docstrings - testgen_context = extract_code_markdown_context_from_files( + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context exceeded token limit, removing docstrings") + testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - code_context_type=CodeContextType.TESTGEN, + include_imported_classes=True, ) - # Re-extract imported classes (they 
may still fit) - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: - # If still over limit, try without imported class definitions - testgen_context = extract_code_markdown_context_from_files( + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context still exceeded token limit, removing imported class definitions") + testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - code_context_type=CodeContextType.TESTGEN, + include_imported_classes=False, ) - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() @@ -184,114 +198,6 @@ def get_code_optimization_context( ) -def extract_code_string_context_from_files( - helpers_of_fto: dict[Path, set[FunctionSource]], - helpers_of_helpers: dict[Path, set[FunctionSource]], - project_root_path: Path, - remove_docstrings: bool = False, # noqa: FBT001, FBT002 - code_context_type: CodeContextType = CodeContextType.READ_ONLY, -) -> CodeString: - """Extract code context from files containing target functions and their helpers. - This function processes two sets of files: - 1. Files containing the function to optimize (fto) and their first-degree helpers - 2. Files containing only helpers of helpers (with no overlap with the first set). - - For each file, it extracts relevant code based on the specified context type, adds necessary - imports, and combines them. - - Args: - ---- - helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers - helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions - project_root_path: Root path of the project - remove_docstrings: Whether to remove docstrings from the extracted code - code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) - - Returns: - ------- - CodeString containing the extracted code context with necessary imports - - """ # noqa: D205 - # Rearrange to remove overlaps, so we only access each file path once - helpers_of_helpers_no_overlap = defaultdict(set) - for file_path, function_sources in helpers_of_helpers.items(): - if file_path in helpers_of_fto: - # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto - helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - else: - helpers_of_helpers_no_overlap[file_path] = function_sources - - final_code_string_context = "" - - # Extract code from file paths that contain fto and first degree helpers. 
helpers of helpers may also be included if they are in the same files - for file_path, function_sources in helpers_of_fto.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = { - func.qualified_name for func in helpers_of_helpers.get(file_path, set()) - } - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_function_names | helpers_of_helpers_qualified_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - qualified_function_names, - helpers_of_helpers_qualified_names, - remove_docstrings, - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - if code_context.strip(): - final_code_string_context += f"\n{code_context}" - final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())), - ) - if code_context_type == CodeContextType.READ_WRITABLE: - return CodeString(code=final_code_string_context) - # Extract code from file paths containing helpers of helpers - for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_helper_function_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - - if code_context.strip(): - final_code_string_context += f"\n{code_context}" - final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ) - return CodeString(code=final_code_string_context) - - def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -526,6 +432,10 @@ class definitions for any classes imported from project modules. This helps the LLM understand the actual class structure (constructors, methods, inheritance) rather than just seeing import statements. + Also recursively extracts base classes when a class inherits from another class + in the same module, ensuring the full inheritance chain is available for + understanding constructor signatures. + Args: code_context: The already extracted code context containing imports project_root_path: Root path of the project @@ -568,6 +478,68 @@ class definitions for any classes imported from project modules. 
This helps class_code_strings: list[CodeString] = [] + module_cache: dict[Path, tuple[str, ast.Module]] = {} + + def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: + if module_path in module_cache: + return module_cache[module_path] + try: + module_source = module_path.read_text(encoding="utf-8") + module_tree = ast.parse(module_source) + except Exception: + return None + else: + module_cache[module_path] = (module_source, module_tree) + return module_source, module_tree + + def extract_class_and_bases( + class_name: str, module_path: Path, module_source: str, module_tree: ast.Module + ) -> None: + """Extract a class and its base classes recursively from the same module.""" + # Skip if already extracted + if (module_path, class_name) in extracted_classes: + return + + # Find the class definition in the module + class_node = None + for node in ast.walk(module_tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + class_node = node + break + + if class_node is None: + return + + # First, recursively extract base classes from the same module + for base in class_node.bases: + base_name = None + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + # For module.ClassName, we skip (cross-module inheritance) + continue + + if base_name and base_name not in existing_definitions: + # Check if base class is defined in the same module + extract_class_and_bases(base_name, module_path, module_source, module_tree) + + # Now extract this class (after its bases, so base classes appear first) + if (module_path, class_name) in extracted_classes: + return # Already added by another path + + lines = module_source.split("\n") + start_line = class_node.lineno + if class_node.decorator_list: + start_line = min(d.lineno for d in class_node.decorator_list) + class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) + + # Extract imports for the class + class_imports = extract_imports_for_class(module_tree, class_node, module_source) + full_source = class_imports + "\n\n" + class_source if class_imports else class_source + + class_code_strings.append(CodeString(code=full_source, file_path=module_path)) + extracted_classes.add((module_path, class_name)) + for name, module_name in imported_names.items(): # Skip if already defined in context if name in existing_definitions: @@ -593,40 +565,127 @@ class definitions for any classes imported from project modules. 
This helps if path_belongs_to_site_packages(module_path): continue - # Skip if we've already extracted this class - if (module_path, name) in extracted_classes: + # Get module source and tree + result = get_module_source_and_tree(module_path) + if result is None: continue + module_source, module_tree = result - # Parse the module to find the class definition - module_source = module_path.read_text(encoding="utf-8") - module_tree = ast.parse(module_source) + # Extract the class and its base classes + extract_class_and_bases(name, module_path, module_source, module_tree) - for node in ast.walk(module_tree): - if isinstance(node, ast.ClassDef) and node.name == name: - # Extract the class source code - lines = module_source.split("\n") - class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno]) + except Exception: + logger.debug(f"Error extracting class definition for {name} from {module_name}") + continue - # Also extract any necessary imports for the class (base classes, type hints) - class_imports = _extract_imports_for_class(module_tree, node, module_source) + return CodeStringsMarkdown(code_strings=class_code_strings) - full_source = class_imports + "\n\n" + class_source if class_imports else class_source - class_code_strings.append(CodeString(code=full_source, file_path=module_path)) - extracted_classes.add((module_path, name)) - break +def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: + """Extract __init__ methods from external library base classes. - except Exception: - logger.debug(f"Error extracting class definition for {name} from {module_name}") + Scans the code context for classes that inherit from external libraries and extracts + just their __init__ methods. This helps the LLM understand constructor signatures + for mocking or instantiation. 
+ """ + import importlib + import inspect + import textwrap + + all_code = "\n".join(cs.code for cs in code_context.code_strings) + + try: + tree = ast.parse(all_code) + except SyntaxError: + return CodeStringsMarkdown(code_strings=[]) + + imported_names: dict[str, str] = {} + external_bases: list[tuple[str, str]] = [] + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + if alias.name != "*": + imported_name = alias.asname if alias.asname else alias.name + imported_names[imported_name] = node.module + elif isinstance(node, ast.ClassDef): + for base in node.bases: + base_name = None + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): + base_name = base.attr + + if base_name and base_name in imported_names: + module_name = imported_names[base_name] + if not _is_project_module(module_name, project_root_path): + external_bases.append((base_name, module_name)) + + if not external_bases: + return CodeStringsMarkdown(code_strings=[]) + + code_strings: list[CodeString] = [] + extracted: set[tuple[str, str]] = set() + + for base_name, module_name in external_bases: + if (module_name, base_name) in extracted: continue - return CodeStringsMarkdown(code_strings=class_code_strings) + try: + module = importlib.import_module(module_name) + base_class = getattr(module, base_name, None) + if base_class is None: + continue + + init_method = getattr(base_class, "__init__", None) + if init_method is None: + continue + try: + init_source = inspect.getsource(init_method) + init_source = textwrap.dedent(init_source) + class_file = Path(inspect.getfile(base_class)) + parts = class_file.parts + if "site-packages" in parts: + idx = parts.index("site-packages") + class_file = Path(*parts[idx + 1 :]) + except (OSError, TypeError): + continue + + class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ") + code_strings.append(CodeString(code=class_source, file_path=class_file)) + extracted.add((module_name, base_name)) + + except (ImportError, ModuleNotFoundError, AttributeError): + logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}") + continue -def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: + return CodeStringsMarkdown(code_strings=code_strings) + + +def _is_project_module(module_name: str, project_root_path: Path) -> bool: + """Check if a module is part of the project (not external/stdlib).""" + import importlib.util + + try: + spec = importlib.util.find_spec(module_name) + except (ImportError, ModuleNotFoundError, ValueError): + return False + else: + if spec is None or spec.origin is None: + return False + module_path = Path(spec.origin) + # Check if the module is in site-packages (external dependency) + # This must be checked first because .venv/site-packages is under project root + if path_belongs_to_site_packages(module_path): + return False + # Check if the module is within the project root + return str(module_path).startswith(str(project_root_path) + os.sep) + + +def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: """Extract import statements needed for a class definition. - This extracts imports for base classes and commonly used type annotations. + This extracts imports for base classes, decorators, and type annotations. 
""" needed_names: set[str] = set() @@ -638,35 +697,139 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef # For things like abc.ABC, we need the module name needed_names.add(base.value.id) + # Get decorator names (e.g., dataclass, field) + for decorator in class_node.decorator_list: + if isinstance(decorator, ast.Name): + needed_names.add(decorator.id) + elif isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name): + needed_names.add(decorator.func.id) + elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name): + needed_names.add(decorator.func.value.id) + + # Get type annotation names from class body (for dataclass fields) + for item in ast.walk(class_node): + if isinstance(item, ast.AnnAssign) and item.annotation: + collect_names_from_annotation(item.annotation, needed_names) + # Also check for field() calls which are common in dataclasses + if isinstance(item, ast.Call) and isinstance(item.func, ast.Name): + needed_names.add(item.func.id) + # Find imports that provide these names import_lines: list[str] = [] source_lines = module_source.split("\n") + added_imports: set[int] = set() # Track line numbers to avoid duplicates for node in module_tree.body: if isinstance(node, ast.Import): for alias in node.names: name = alias.asname if alias.asname else alias.name.split(".")[0] - if name in needed_names: + if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) + added_imports.add(node.lineno) break elif isinstance(node, ast.ImportFrom): for alias in node.names: name = alias.asname if alias.asname else alias.name - if name in needed_names: + if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) + added_imports.add(node.lineno) break return "\n".join(import_lines) +def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: + """Recursively collect type annotation names from an AST node.""" + if isinstance(node, ast.Name): + names.add(node.id) + elif isinstance(node, ast.Subscript): + collect_names_from_annotation(node.value, names) + collect_names_from_annotation(node.slice, names) + elif isinstance(node, ast.Tuple): + for elt in node.elts: + collect_names_from_annotation(elt, names) + elif isinstance(node, ast.BinOp): # For Union types with | syntax + collect_names_from_annotation(node.left, names) + collect_names_from_annotation(node.right, names) + elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + names.add(node.value.id) + + def is_dunder_method(name: str) -> bool: return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__") -def get_section_names(node: cst.CSTNode) -> list[str]: - """Returns the section attribute names (e.g., body, orelse) for a given node if they exist.""" # noqa: D401 - possible_sections = ["body", "orelse", "finalbody", "handlers"] - return [sec for sec in possible_sections if hasattr(node, sec)] +class UsedNameCollector(cst.CSTVisitor): + """Collects all base names referenced in code (for import preservation).""" + + def __init__(self) -> None: + self.used_names: set[str] = set() + self.defined_names: set[str] = set() + + def visit_Name(self, node: cst.Name) -> None: + self.used_names.add(node.value) + + def visit_Attribute(self, node: cst.Attribute) -> bool | None: + base = node.value + while isinstance(base, cst.Attribute): + base = base.value + if isinstance(base, cst.Name): + 
self.used_names.add(base.value) + return True + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + self.defined_names.add(node.name.value) + return True + + def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: + self.defined_names.add(node.name.value) + return True + + def visit_Assign(self, node: cst.Assign) -> bool | None: + for target in node.targets: + names = extract_names_from_targets(target.target) + self.defined_names.update(names) + return True + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + names = extract_names_from_targets(node.target) + self.defined_names.update(names) + return True + + def get_external_names(self) -> set[str]: + return self.used_names - self.defined_names - {"self", "cls"} + + +def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]: + """Extract the names made available by an import statement.""" + names: set[str] = set() + if isinstance(import_node, cst.Import): + if isinstance(import_node.names, cst.ImportStar): + return {"*"} + for alias in import_node.names: + if isinstance(alias, cst.ImportAlias): + if alias.asname and isinstance(alias.asname.name, cst.Name): + names.add(alias.asname.name.value) + elif isinstance(alias.name, cst.Name): + names.add(alias.name.value) + elif isinstance(alias.name, cst.Attribute): + # import foo.bar -> accessible as "foo" + base: cst.BaseExpression = alias.name + while isinstance(base, cst.Attribute): + base = base.value + if isinstance(base, cst.Name): + names.add(base.value) + elif isinstance(import_node, cst.ImportFrom): + if isinstance(import_node.names, cst.ImportStar): + return {"*"} + for alias in import_node.names: + if isinstance(alias, cst.ImportAlias): + if alias.asname and isinstance(alias.asname.name, cst.Name): + names.add(alias.asname.name.value) + elif isinstance(alias.name, cst.Name): + names.add(alias.name.value) + return names def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: @@ -693,12 +856,22 @@ def parse_code_and_prune_cst( if code_context_type == CodeContextType.READ_WRITABLE: filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages) elif code_context_type == CodeContextType.READ_ONLY: - filtered_node, found_target = prune_cst_for_read_only_code( - module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + filtered_node, found_target = prune_cst_for_context( + module, + target_functions, + helpers_of_helper_functions, + remove_docstrings=remove_docstrings, + include_target_in_output=False, + include_init_dunder=False, ) elif code_context_type == CodeContextType.TESTGEN: - filtered_node, found_target = prune_cst_for_testgen_code( - module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + filtered_node, found_target = prune_cst_for_context( + module, + target_functions, + helpers_of_helper_functions, + remove_docstrings=remove_docstrings, + include_target_in_output=True, + include_init_dunder=True, ) elif code_context_type == CodeContextType.HASHING: filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions) @@ -740,10 +913,29 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 # Do not recurse into nested classes if prefix: return None, False + + class_name = node.name.value + # Assuming always an IndentedBlock if not isinstance(node.body, cst.IndentedBlock): raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = 
f"{prefix}.{node.name.value}" if prefix else node.name.value + class_prefix = f"{prefix}.{class_name}" if prefix else class_name + + # Check if this class contains any target functions + has_target_functions = any( + isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions + for stmt in node.body.body + ) + + # If the class is used as a dependency (not containing target functions), keep it entirely + # This handles cases like enums, dataclasses, and other types used by the target function + if ( + not has_target_functions + and class_name in defs_with_usages + and defs_with_usages[class_name].used_by_qualified_function + ): + return node, True + new_body = [] found_target = False @@ -903,17 +1095,29 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 return (node.with_changes(**updates) if updates else node), True -def prune_cst_for_read_only_code( # noqa: PLR0911 +def prune_cst_for_context( # noqa: PLR0911 node: cst.CSTNode, target_functions: set[str], helpers_of_helper_functions: set[str], prefix: str = "", remove_docstrings: bool = False, # noqa: FBT001, FBT002 + include_target_in_output: bool = False, # noqa: FBT001, FBT002 + include_init_dunder: bool = False, # noqa: FBT001, FBT002 ) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for read-only context. + """Recursively filter the node for code context extraction. - Returns - ------- + Args: + node: The CST node to filter + target_functions: Set of qualified function names that are targets + helpers_of_helper_functions: Set of helper function qualified names + prefix: Current qualified name prefix (for class methods) + remove_docstrings: Whether to remove docstrings from output + include_target_in_output: If True, include target functions in output (testgen mode) + If False, exclude target functions (read-only mode) + include_init_dunder: If True, include __init__ in dunder methods (testgen mode) + If False, exclude __init__ from dunder methods (read-only mode) + + Returns: (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. 
@@ -924,17 +1128,28 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # If it's a target function, remove it but mark found_target = True + + # Check if it's a helper of helper function if qualified_name in helpers_of_helper_functions: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True + + # Check if it's a target function if qualified_name in target_functions: + if include_target_in_output: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + return node.with_changes(body=remove_docstring_from_body(node.body)), True + return node, True return None, True - # Keep only dunder methods - if is_dunder_method(node.name.value) and node.name.value != "__init__": + + # Check dunder methods + # For read-only mode, exclude __init__; for testgen mode, include all dunders + if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"): if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), False + return node.with_changes(body=remove_docstring_from_body(node.body)), False return node, False + return None, False if isinstance(node, cst.ClassDef): @@ -951,114 +1166,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 found_in_class = False new_class_body: list[CSTNode] = [] for stmt in node.body.body: - filtered, found_target = prune_cst_for_read_only_code( - stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings - ) - found_in_class |= found_target - if filtered: - new_class_body.append(filtered) - - if not found_in_class: - return None, False - - if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True - return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True - - # For other nodes, keep the node and recursively filter children - section_names = get_section_names(node) - if not section_names: - return node, False - - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_read_only_code( - child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings - ) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target or new_children: - found_any_target |= section_found_target - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_read_only_code( - original_content, + filtered, found_target = prune_cst_for_context( + stmt, target_functions, helpers_of_helper_functions, - prefix, + class_prefix, remove_docstrings=remove_docstrings, - ) - found_any_target |= found_target - if filtered: - updates[section] = filtered - if updates: - return (node.with_changes(**updates), found_any_target) - - return None, False - - -def prune_cst_for_testgen_code( # noqa: PLR0911 - node: cst.CSTNode, - 
target_functions: set[str], - helpers_of_helper_functions: set[str], - prefix: str = "", - remove_docstrings: bool = False, # noqa: FBT001, FBT002 -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for testgen context. - - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. - - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # If it's a target function, remove it but mark found_target = True - if qualified_name in helpers_of_helper_functions or qualified_name in target_functions: - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), True - return node, True - # Keep all dunder methods - if is_dunder_method(node.name.value): - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), False - return node, False - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value - - # First pass: detect if there is a target function in the class - found_in_class = False - new_class_body: list[CSTNode] = [] - for stmt in node.body.body: - filtered, found_target = prune_cst_for_testgen_code( - stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings + include_target_in_output=include_target_in_output, + include_init_dunder=include_init_dunder, ) found_in_class |= found_target if filtered: @@ -1087,8 +1202,14 @@ def prune_cst_for_testgen_code( # noqa: PLR0911 new_children = [] section_found_target = False for child in original_content: - filtered, found_target = prune_cst_for_testgen_code( - child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings + filtered, found_target = prune_cst_for_context( + child, + target_functions, + helpers_of_helper_functions, + prefix, + remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + include_init_dunder=include_init_dunder, ) if filtered: new_children.append(filtered) @@ -1098,16 +1219,19 @@ def prune_cst_for_testgen_code( # noqa: PLR0911 found_any_target |= section_found_target updates[section] = new_children elif original_content is not None: - filtered, found_target = prune_cst_for_testgen_code( + filtered, found_target = prune_cst_for_context( original_content, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + include_init_dunder=include_init_dunder, ) found_any_target |= found_target if filtered: updates[section] = filtered + if updates: return (node.with_changes(**updates), found_any_target) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 823cb735b..107cfe0a7 100644 --- a/codeflash/context/unused_definition_remover.py +++ 
b/codeflash/context/unused_definition_remover.py @@ -295,11 +295,18 @@ def visit_Name(self, node: cst.Name) -> None: return if name in self.definitions and name != self.current_top_level_name: - # skip if we are refrencing a class attribute and not a top-level definition + # Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x') + # We only want to track the base/value of attribute access, not the attribute name itself if self.class_depth > 0: parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) if parent is not None and isinstance(parent, cst.Attribute): - return + # Check if this Name is the .attr (property name), not the .value (base) + # If it's the .attr, skip it - attribute names aren't references to definitions + if parent.attr is node: + return + # If it's the .value (base), only skip if it's self/cls + if name in ("self", "cls"): + return self.definitions[self.current_top_level_name].dependencies.add(name) @@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na return code -def print_definitions(definitions: dict[str, UsageInfo]) -> None: - """Print information about each definition without the complex node object, used for debugging.""" - print(f"Found {len(definitions)} definitions:") - for name, info in sorted(definitions.items()): - print(f" - Name: {name}") - print(f" Used by qualified function: {info.used_by_qualified_function}") - print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}") - print() - - def revert_unused_helper_functions( project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] ) -> None: @@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code( func_name = helper.only_function_name module_name = helper.file_path.stem # Cache function lookup for this (module, func) - file_entry = helpers_by_file_and_func[module_name] - if func_name in file_entry: - file_entry[func_name].append(helper) - else: - file_entry[func_name] = [helper] + helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper) helpers_by_file[module_name].append(helper) - # Optimize attribute lookups and method binding outside the loop - helpers_by_file_and_func_get = helpers_by_file_and_func.get - helpers_by_file_get = helpers_by_file.get - for node in ast.walk(optimized_ast): if isinstance(node, ast.ImportFrom): # Handle "from module import function" statements module_name = node.module if module_name: - file_entry = helpers_by_file_and_func_get(module_name, None) + file_entry = helpers_by_file_and_func.get(module_name) if file_entry: for alias in node.names: imported_name = alias.asname if alias.asname else alias.name original_name = alias.name - helpers = file_entry.get(original_name, None) + helpers = file_entry.get(original_name) if helpers: + imported_set = imported_names_map[imported_name] for helper in helpers: - imported_names_map[imported_name].add(helper.qualified_name) - imported_names_map[imported_name].add(helper.fully_qualified_name) + imported_set.add(helper.qualified_name) + imported_set.add(helper.fully_qualified_name) elif isinstance(node, ast.Import): # Handle "import module" statements for alias in node.names: imported_name = alias.asname if alias.asname else alias.name module_name = alias.name - for helper in helpers_by_file_get(module_name, []): - # For "import module" statements, functions would be called as module.function - full_call = f"{imported_name}.{helper.only_function_name}" - 
imported_names_map[full_call].add(helper.qualified_name) - imported_names_map[full_call].add(helper.fully_qualified_name) + helpers = helpers_by_file.get(module_name) + if helpers: + imported_set = imported_names_map[f"{imported_name}.{{func}}"] + for helper in helpers: + # For "import module" statements, functions would be called as module.function + full_call = f"{imported_name}.{helper.only_function_name}" + full_call_set = imported_names_map[full_call] + full_call_set.add(helper.qualified_name) + full_call_set.add(helper.fully_qualified_name) return dict(imported_names_map) @@ -753,27 +747,31 @@ def detect_unused_helper_functions( called_name = node.func.id called_function_names.add(called_name) # Also add the qualified name if this is an imported function - if called_name in imported_names_map: - called_function_names.update(imported_names_map[called_name]) + mapped_names = imported_names_map.get(called_name) + if mapped_names: + called_function_names.update(mapped_names) elif isinstance(node.func, ast.Attribute): # Method call: obj.method() or self.method() or module.function() if isinstance(node.func.value, ast.Name): - if node.func.value.id == "self": + attr_name = node.func.attr + value_id = node.func.value.id + if value_id == "self": # self.method_name() -> add both method_name and ClassName.method_name - called_function_names.add(node.func.attr) + called_function_names.add(attr_name) + # For class methods, also add the qualified name # For class methods, also add the qualified name if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{node.func.attr}") + called_function_names.add(f"{class_name}.{attr_name}") else: - # obj.method() or module.function() - attr_name = node.func.attr called_function_names.add(attr_name) - called_function_names.add(f"{node.func.value.id}.{attr_name}") + full_call = f"{value_id}.{attr_name}" + called_function_names.add(full_call) # Check if this is a module.function call that maps to a helper - full_call = f"{node.func.value.id}.{attr_name}" - if full_call in imported_names_map: - called_function_names.update(imported_names_map[full_call]) + mapped_names = imported_names_map.get(full_call) + if mapped_names: + called_function_names.update(mapped_names) + # Handle nested attribute access like obj.attr.method() # Handle nested attribute access like obj.attr.method() else: called_function_names.add(node.func.attr) @@ -783,6 +781,7 @@ def detect_unused_helper_functions( # Find helper functions that are no longer called unused_helpers = [] + entrypoint_file_path = function_to_optimize.file_path for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": # Check if the helper function is called using multiple name variants @@ -790,29 +789,30 @@ def detect_unused_helper_functions( helper_simple_name = helper_function.only_function_name helper_fully_qualified_name = helper_function.fully_qualified_name - # Create a set of all possible names this helper might be called by - possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name} - + # Check membership efficiently - exit early on first match + if ( + helper_qualified_name in called_function_names + or helper_simple_name in called_function_names + or helper_fully_qualified_name in called_function_names + ): + is_called = True # For cross-file helpers, also consider module-based calls - if 
helper_function.file_path != function_to_optimize.file_path: + elif helper_function.file_path != entrypoint_file_path: # Add potential module.function combinations module_name = helper_function.file_path.stem - possible_call_names.add(f"{module_name}.{helper_simple_name}") - - # Check if any of the possible names are in the called functions - is_called = bool(possible_call_names.intersection(called_function_names)) + module_call = f"{module_name}.{helper_simple_name}" + is_called = module_call in called_function_names + else: + is_called = False if not is_called: unused_helpers.append(helper_function) logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") - logger.debug(f" Checked names: {possible_call_names}") else: logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") - logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}") - - ret_val = unused_helpers except Exception as e: logger.debug(f"Error detecting unused helper functions: {e}") - ret_val = [] - return ret_val + return [] + else: + return unused_helpers diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index 991f4d624..5c2bf4b6f 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -15,6 +15,8 @@ import dill as pickle from dill import PicklingWarning +from codeflash.picklepatch.pickle_patcher import PicklePatcher + warnings.filterwarnings("ignore", category=PicklingWarning) @@ -148,18 +150,29 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003 print(f"!######{test_stdout_tag}######!") # Capture instance state after initialization - if hasattr(args[0], "__dict__"): - instance_state = args[ - 0 - ].__dict__ # self is always the first argument, this is ensured during instrumentation + # self is always the first argument, this is ensured during instrumentation + instance = args[0] + if hasattr(instance, "__dict__"): + instance_state = instance.__dict__ + elif hasattr(instance, "__slots__"): + # For classes using __slots__, capture slot values + instance_state = { + slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot) + } else: - raise ValueError("Instance state could not be captured.") + # For C extension types or other special classes (e.g., Playwright's Page), + # capture all non-private, non-callable attributes + instance_state = { + attr: getattr(instance, attr) + for attr in dir(instance) + if not attr.startswith("_") and not callable(getattr(instance, attr, None)) + } codeflash_cur.execute( "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" ) # Write to sqlite - pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state) + pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state) codeflash_cur.execute( "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 03015ab24..0705d2581 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -19,6 +19,14 @@ test_diff_repr = reprlib_repr.repr +def safe_repr(obj: object) -> str: + """Safely get repr of an object, 
handling Mock objects with corrupted state."""
+    try:
+        return repr(obj)
+    except (AttributeError, TypeError, RecursionError) as e:
+        return f"<unrepresentable object of type {type(obj).__name__}: {e}>"
+
+
 def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
     # This is meant to be only called with test results for the first loop index
     if len(original_results) == 0 or len(candidate_results) == 0:
@@ -77,8 +85,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
             test_diffs.append(
                 TestDiff(
                     scope=TestDiffScope.RETURN_VALUE,
-                    original_value=test_diff_repr(repr(original_test_result.return_value)),
-                    candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
+                    original_value=test_diff_repr(safe_repr(original_test_result.return_value)),
+                    candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)),
                     test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
                     candidate_pytest_error=cdd_pytest_error,
                     original_pass=original_test_result.did_pass,
diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py
index 0593c37bc..71db216e4 100644
--- a/tests/test_code_context_extractor.py
+++ b/tests/test_code_context_extractor.py
@@ -7,13 +7,19 @@ from pathlib import Path
 
 import pytest
 
-from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions
-from codeflash.models.models import CodeString, CodeStringsMarkdown
+
+from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
+from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
+from codeflash.context.code_context_extractor import (
+    collect_names_from_annotation,
+    extract_imports_for_class,
+    get_code_optimization_context,
+    get_external_base_class_inits,
+    get_imported_class_definitions,
+)
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import FunctionParent
+from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
 from codeflash.optimization.optimizer import Optimizer
-from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
-from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
 
 
 class HelperClass:
@@ -86,7 +92,10 @@ def test_code_replacement10() -> None:
     code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
     qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
     # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
-    assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"}  # Nested method should not be in here
+    assert qualified_names == {
+        "HelperClass.helper_method",
+        "HelperClass.__init__",
+    }  # Nested method should not be in here
     read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
     hashing_context = code_ctx.hashing_code_context
 
@@ -229,7 +238,7 @@ def test_bubble_sort_helper() -> None:
     read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
     hashing_context = code_ctx.hashing_code_context
 
-    expected_read_write_context = f"""
+    expected_read_write_context = """
 ```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
 import math
 
@@ -1103,7 +1112,9 @@ def helper_method(self):
     code_ctx = 
get_code_optimization_context(function_to_optimize, opt.args.project_root) # the global x variable shouldn't be included in any context type - assert code_ctx.read_writable_code.flat == '''# file: test_code.py + assert ( + code_ctx.read_writable_code.flat + == '''# file: test_code.py class MyClass: def __init__(self): self.x = 1 @@ -1118,7 +1129,10 @@ def __init__(self): def helper_method(self): return self.x ''' - assert code_ctx.testgen_context.flat == '''# file: test_code.py + ) + assert ( + code_ctx.testgen_context.flat + == '''# file: test_code.py class MyClass: """A class with a helper method. """ def __init__(self): @@ -1138,6 +1152,7 @@ def __repr__(self): def helper_method(self): return self.x ''' + ) def test_repo_helper() -> None: @@ -2348,9 +2363,7 @@ def standalone_function(): assert '"""Helper method with docstring."""' not in hashing_context, ( "Docstrings should be removed from helper functions" ) - assert '"""Process data method."""' not in hashing_context, ( - "Docstrings should be removed from helper class methods" - ) + assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods" def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None: @@ -2588,16 +2601,21 @@ def test_circular_deps(): optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") content = Path(file_abs_path).read_text(encoding="utf-8") new_code = replace_functions_and_add_imports( - source_code= add_global_assignments(optimized_code, content), - function_names= ["ApiClient.get_console_url"], - optimized_code= optimized_code, - module_abspath= Path(file_abs_path), - preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))}, - project_root_path= Path(path_to_root), + source_code=add_global_assignments(optimized_code, content), + function_names=["ApiClient.get_console_url"], + optimized_code=optimized_code, + module_abspath=Path(file_abs_path), + preexisting_objects={ + ("ApiClient", ()), + ("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)), + }, + project_root_path=Path(path_to_root), ) assert "import ApiClient" not in new_code, "Error: Circular dependency found" assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" + + def test_global_assignment_collector_with_async_function(): """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" import libcst as cst @@ -2745,6 +2763,380 @@ async def async_function(): assert collector.assignment_order == expected_order +def test_global_assignment_collector_annotated_assignments(): + """Test GlobalAssignmentCollector correctly handles annotated assignments (AnnAssign).""" + import libcst as cst + + source_code = """ +# Regular global assignment +REGULAR_VAR = "regular" + +# Annotated global assignments +TYPED_VAR: str = "typed" +CACHE: dict[str, int] = {} +SENTINEL: object = object() + +# Annotated without value (type declaration only) - should NOT be collected +DECLARED_ONLY: int + +def some_function(): + # Annotated assignment inside function - should not be collected + local_typed: str = "local" + return local_typed + +class SomeClass: + # Class-level annotated assignment - should not be collected + class_attr: str = "class" + +# Another regular assignment +FINAL_VAR = 123 +""" + + tree = cst.parse_module(source_code) + collector = GlobalAssignmentCollector() + tree.visit(collector) + + # Should 
collect both regular and annotated global assignments with values + assert len(collector.assignments) == 5 + assert "REGULAR_VAR" in collector.assignments + assert "TYPED_VAR" in collector.assignments + assert "CACHE" in collector.assignments + assert "SENTINEL" in collector.assignments + assert "FINAL_VAR" in collector.assignments + + # Should not collect type declarations without values + assert "DECLARED_ONLY" not in collector.assignments + + # Should not collect assignments from inside functions or classes + assert "local_typed" not in collector.assignments + assert "class_attr" not in collector.assignments + + # Verify correct order + expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"] + assert collector.assignment_order == expected_order + + +def test_global_function_collector(): + """Test GlobalFunctionCollector correctly collects module-level function definitions.""" + import libcst as cst + + from codeflash.code_utils.code_extractor import GlobalFunctionCollector + + source_code = """ +# Module-level functions +def helper_function(): + return "helper" + +def another_helper(x: int) -> str: + return str(x) + +class SomeClass: + def method(self): + # This is a method, not a module-level function + return "method" + + def another_method(self): + # Also a method + def nested_function(): + # Nested function inside method + return "nested" + return nested_function() + +def final_function(): + def inner_function(): + # This is a nested function, not module-level + return "inner" + return inner_function() +""" + + tree = cst.parse_module(source_code) + collector = GlobalFunctionCollector() + tree.visit(collector) + + # Should collect only module-level functions + assert len(collector.functions) == 3 + assert "helper_function" in collector.functions + assert "another_helper" in collector.functions + assert "final_function" in collector.functions + + # Should not collect methods or nested functions + assert "method" not in collector.functions + assert "another_method" not in collector.functions + assert "nested_function" not in collector.functions + assert "inner_function" not in collector.functions + + # Verify correct order + expected_order = ["helper_function", "another_helper", "final_function"] + assert collector.function_order == expected_order + + +def test_add_global_assignments_with_new_functions(): + """Test add_global_assignments correctly adds new module-level functions.""" + source_code = """\ +from functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + return _get_decorator_for_action(action) + +@lru_cache(maxsize=None) +def _get_decorator_for_action(action): + def decorator(fn): + return fn + return decorator +""" + + destination_code = """\ +from functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + # Original implementation + return action +""" + + expected = """\ +from functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + # Original implementation + return action + + +@lru_cache(maxsize=None) +def _get_decorator_for_action(action): + def decorator(fn): + return fn + return decorator +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_does_not_duplicate_existing_functions(): + """Test add_global_assignments does not duplicate functions that already exist in destination.""" + source_code = """\ +def helper(): + return "source_helper" + +def 
existing_function(): + return "source_existing" +""" + + destination_code = """\ +def existing_function(): + return "dest_existing" + +class MyClass: + pass +""" + + expected = """\ +def existing_function(): + return "dest_existing" + +class MyClass: + pass + +def helper(): + return "source_helper" +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_with_decorated_functions(): + """Test add_global_assignments correctly adds decorated functions.""" + source_code = """\ +from functools import lru_cache +from typing import Callable + +_LOCAL_CACHE: dict[str, int] = {} + +@lru_cache(maxsize=128) +def cached_helper(x: int) -> int: + return x * 2 + +def regular_helper(): + return "regular" +""" + + destination_code = """\ +from typing import Any + +class MyClass: + def method(self): + return cached_helper(5) +""" + + expected = """\ +from typing import Any + +_LOCAL_CACHE: dict[str, int] = {} + +class MyClass: + def method(self): + return cached_helper(5) + + +@lru_cache(maxsize=128) +def cached_helper(x: int) -> int: + return x * 2 + + +def regular_helper(): + return "regular" +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_references_class_defined_in_module(): + """Test that global assignments referencing classes are placed after those class definitions. + + This test verifies the fix for a bug where LLM-generated optimization code like: + _REIFIERS = {MessageKind.XXX: lambda d: ...} + was placed BEFORE the MessageKind class definition, causing NameError at module load. + + The fix ensures that new global assignments are inserted AFTER all class/function + definitions in the module, so they can safely reference any class defined in the module. + """ + source_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} + +def handle_message(kind): + return _MESSAGE_HANDLERS[kind]() +""" + + destination_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" +""" + + # Global assignments are now inserted AFTER class/function definitions + # to ensure they can reference classes defined in the module + expected = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_function_calls_after_function_definitions(): + """Test that global function calls are placed after the functions they reference. + + This test verifies the fix for a bug where LLM-generated optimization code like: + def _register(kind, factory): + _factories[kind] = factory + + _register(MessageKind.ASK, lambda: "ask") + + would have the _register(...) calls placed BEFORE the _register function definition, + causing NameError at module load time. 
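+
+    A minimal sketch of the broken ordering (hypothetical module emitted
+    before the fix):
+
+        _register(MessageKind.ASK, lambda: "ask")  # NameError: _register not defined yet
+
+        def _register(kind, factory):
+            _factories[kind] = factory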
+ + The fix ensures that new global statements (like function calls) are inserted AFTER + all class/function definitions, so they can safely reference any function defined in + the module. + """ + source_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +_factories = {} + +def _register(kind, factory): + _factories[kind] = factory + +_register(MessageKind.ASK, lambda: "ask handler") +_register(MessageKind.REPLY, lambda: "reply handler") + +def handle_message(kind): + return _factories[kind]() +""" + + destination_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" +""" + + expected = """\ +import enum + +_factories = {} + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" + + +def _register(kind, factory): + _factories[kind] = factory + + +_register(MessageKind.ASK, lambda: "ask handler") + +_register(MessageKind.REPLY, lambda: "reply handler") +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None: """Test that when a class is instantiated, its __init__ method is tracked as a helper. @@ -2785,11 +3177,7 @@ def target_function(): ) ) function_to_optimize = FunctionToOptimize( - function_name="target_function", - file_path=file_path, - parents=[], - starting_line=None, - ending_line=None, + function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None ) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) @@ -2803,15 +3191,11 @@ def target_function(): # The testgen context should contain the class with __init__ (critical for LLM to know constructor) testgen_context = code_ctx.testgen_context.markdown assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context" - assert "def __init__(self, data):" in testgen_context, ( - "__init__ method should be included in testgen context" - ) + assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context" # The hashing context should NOT contain __init__ (excluded for stability) hashing_context = code_ctx.hashing_code_context - assert "__init__" not in hashing_context, ( - "__init__ should NOT be in hashing context (excluded for hash stability)" - ) + assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)" def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None: @@ -2865,11 +3249,7 @@ def dump_layout(layout_type, layout): ) ) function_to_optimize = FunctionToOptimize( - function_name="dump_layout", - file_path=file_path, - parents=[], - starting_line=None, - ending_line=None, + function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None ) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) @@ -2879,9 +3259,7 @@ def dump_layout(layout_type, layout): assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, ( "ObjectDetectionLayoutDumper.__init__ should be tracked" ) - assert "LayoutDumper.__init__" in qualified_names, ( - "LayoutDumper.__init__ should be tracked" - ) + assert "LayoutDumper.__init__" in qualified_names, 
"LayoutDumper.__init__ should be tracked" # The testgen context should include both classes with their __init__ methods testgen_context = code_ctx.testgen_context.markdown @@ -2891,9 +3269,7 @@ def dump_layout(layout_type, layout): ) # Both __init__ methods should be in the testgen context (so LLM knows constructor signatures) - assert testgen_context.count("def __init__") >= 2, ( - "Both __init__ methods should be in testgen context" - ) + assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context" def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None: @@ -2929,7 +3305,7 @@ def __init__(self, text: str, element_id: str = None): elements_path.write_text(elements_code, encoding="utf-8") # Create another module that imports from elements - chunking_code = ''' + chunking_code = """ from mypackage.elements import Element class PreChunk: @@ -2939,14 +3315,12 @@ def __init__(self, elements: list[Element]): class Accumulator: def will_fit(self, chunk: PreChunk) -> bool: return True -''' +""" chunking_path = package_dir / "chunking.py" chunking_path.write_text(chunking_code, encoding="utf-8") # Create CodeStringsMarkdown from the chunking module (simulating testgen context) - context = CodeStringsMarkdown( - code_strings=[CodeString(code=chunking_code, file_path=chunking_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -2970,16 +3344,16 @@ def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Pat (package_dir / "__init__.py").write_text("", encoding="utf-8") # Create a module with a class definition - elements_code = ''' + elements_code = """ class Element: def __init__(self, text: str): self.text = text -''' +""" elements_path = package_dir / "elements.py" elements_path.write_text(elements_code, encoding="utf-8") # Create code that imports Element but also redefines it locally - code_with_local_def = ''' + code_with_local_def = """ from mypackage.elements import Element # Local redefinition (this happens when LLM redefines classes) @@ -2990,13 +3364,11 @@ def __init__(self, text: str): class User: def process(self, elem: Element): pass -''' +""" code_path = package_dir / "user.py" code_path.write_text(code_with_local_def, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code_with_local_def, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3013,7 +3385,7 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non (package_dir / "__init__.py").write_text("", encoding="utf-8") # Code with stdlib/third-party imports - code = ''' + code = """ from pathlib import Path from typing import Optional from dataclasses import dataclass @@ -3021,13 +3393,11 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non class MyClass: def __init__(self, path: Path): self.path = path -''' +""" code_path = package_dir / "main.py" code_path.write_text(code, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) # Call 
get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3044,7 +3414,7 @@ def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) (package_dir / "__init__.py").write_text("", encoding="utf-8") # Create a module with multiple class definitions - types_code = ''' + types_code = """ class TypeA: def __init__(self, value: int): self.value = value @@ -3056,24 +3426,22 @@ def __init__(self, name: str): class TypeC: def __init__(self): pass -''' +""" types_path = package_dir / "types.py" types_path.write_text(types_code, encoding="utf-8") # Create code that imports multiple classes - code = ''' + code = """ from mypackage.types import TypeA, TypeB class Processor: def process(self, a: TypeA, b: TypeB): pass -''' +""" code_path = package_dir / "processor.py" code_path.write_text(code, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3085,3 +3453,1200 @@ def process(self, a: TypeA, b: TypeB): assert "class TypeA" in all_extracted_code, "Should contain TypeA class" assert "class TypeB" in all_extracted_code, "Should contain TypeB class" assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)" + + +def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None: + """Test that get_imported_class_definitions includes decorators when extracting dataclasses.""" + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with dataclass definitions (like LLMConfig in skyvern) + models_code = """from dataclasses import dataclass, field +from typing import Optional + +@dataclass(frozen=True) +class LLMConfigBase: + model_name: str + required_env_vars: list[str] + supports_vision: bool + add_assistant_prefix: bool + +@dataclass(frozen=True) +class LLMConfig(LLMConfigBase): + litellm_params: Optional[dict] = field(default=None) + max_tokens: int | None = None +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports the dataclass + code = """from mypackage.models import LLMConfig + +class ConfigRegistry: + def get_config(self) -> LLMConfig: + pass +""" + code_path = package_dir / "registry.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + # Call get_imported_class_definitions + result = get_imported_class_definitions(context, tmp_path) + + # Should extract both LLMConfigBase (base class) and LLMConfig + assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase" + + # Combine extracted code to check for all required elements + all_extracted_code = "\n".join(cs.code for cs in result.code_strings) + + # Verify the base class is extracted first (for proper inheritance understanding) + base_class_idx = all_extracted_code.find("class LLMConfigBase") + derived_class_idx = all_extracted_code.find("class LLMConfig(") + assert base_class_idx < derived_class_idx, "Base class should appear before derived class" + + # Verify both classes include @dataclass decorators + assert 
all_extracted_code.count("@dataclass(frozen=True)") == 2, ( + "Should include @dataclass decorator for both classes" + ) + assert "class LLMConfig" in all_extracted_code, "Should contain LLMConfig class definition" + assert "class LLMConfigBase" in all_extracted_code, "Should contain LLMConfigBase class definition" + + # Verify imports are included for dataclass-related items + assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import" + + +def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: + """Test that extract_imports_for_class includes decorator and type annotation imports.""" + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with decorated class that uses field() and various type annotations + models_code = """from dataclasses import dataclass, field +from typing import Optional, List + +@dataclass +class Config: + name: str + values: List[int] = field(default_factory=list) + description: Optional[str] = None +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports the class + code = """from mypackage.models import Config + +def create_config() -> Config: + return Config(name="test") +""" + code_path = package_dir / "main.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + result = get_imported_class_definitions(context, tmp_path) + + assert len(result.code_strings) == 1, "Should extract Config class" + extracted_code = result.code_strings[0].code + + # The extracted code should include the decorator + assert "@dataclass" in extracted_code, "Should include @dataclass decorator" + # The imports should include dataclass and field + assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator" + + +class TestCollectNamesFromAnnotation: + """Tests for the collect_names_from_annotation helper function.""" + + def test_simple_name(self): + """Test extracting a simple type name.""" + import ast + + code = "def f(x: MyClass): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "MyClass" in names + + def test_subscript_type(self): + """Test extracting names from generic types like List[int].""" + import ast + + code = "def f(x: List[int]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "List" in names + assert "int" in names + + def test_optional_type(self): + """Test extracting names from Optional[MyClass].""" + import ast + + code = "def f(x: Optional[MyClass]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "Optional" in names + assert "MyClass" in names + + def test_union_type_with_pipe(self): + """Test extracting names from union types with | syntax.""" + import ast + + code = "def f(x: int | str | None): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + # int | str | None becomes BinOp nodes + assert "int" in names + assert "str" in names + + def 
test_nested_generic_types(self): + """Test extracting names from nested generics like Dict[str, List[MyClass]].""" + import ast + + code = "def f(x: Dict[str, List[MyClass]]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "Dict" in names + assert "str" in names + assert "List" in names + assert "MyClass" in names + + def test_tuple_annotation(self): + """Test extracting names from tuple type hints.""" + import ast + + code = "def f(x: tuple[int, str, MyClass]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "tuple" in names + assert "int" in names + assert "str" in names + assert "MyClass" in names + + +class TestExtractImportsForClass: + """Tests for the extract_imports_for_class helper function.""" + + def test_extracts_base_class_imports(self): + """Test that base class imports are extracted.""" + import ast + + module_source = """from abc import ABC +from mypackage import BaseClass + +class MyClass(BaseClass, ABC): + pass +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from abc import ABC" in result + assert "from mypackage import BaseClass" in result + + def test_extracts_decorator_imports(self): + """Test that decorator imports are extracted.""" + import ast + + module_source = """from dataclasses import dataclass +from functools import lru_cache + +@dataclass +class MyClass: + name: str +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from dataclasses import dataclass" in result + + def test_extracts_type_annotation_imports(self): + """Test that type annotation imports are extracted.""" + import ast + + module_source = """from typing import Optional, List +from mypackage.models import Config + +@dataclass +class MyClass: + config: Optional[Config] + items: List[str] +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from typing import Optional, List" in result + assert "from mypackage.models import Config" in result + + def test_extracts_field_function_imports(self): + """Test that field() function imports are extracted for dataclasses.""" + import ast + + module_source = """from dataclasses import dataclass, field +from typing import List + +@dataclass +class MyClass: + items: List[str] = field(default_factory=list) +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from dataclasses import dataclass, field" in result + + def test_no_duplicate_imports(self): + """Test that duplicate imports are not included.""" + import ast + + module_source = """from typing import Optional + +@dataclass +class MyClass: + field1: Optional[str] + field2: Optional[int] +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + # Should only have one import line even though Optional 
is used twice + assert result.count("from typing import Optional") == 1 + + +def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None: + """Test that classes with multiple decorators are extracted correctly.""" + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + models_code = """from dataclasses import dataclass +from functools import total_ordering + +@total_ordering +@dataclass +class OrderedConfig: + name: str + priority: int + + def __lt__(self, other): + return self.priority < other.priority +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + code = """from mypackage.models import OrderedConfig + +def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: + return sorted(configs) +""" + code_path = package_dir / "main.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + result = get_imported_class_definitions(context, tmp_path) + + assert len(result.code_strings) == 1 + extracted_code = result.code_strings[0].code + + # Both decorators should be included + assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator" + assert "@dataclass" in extracted_code, "Should include @dataclass decorator" + assert "class OrderedConfig" in extracted_code + + +def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None: + """Test that base classes are recursively extracted for multi-level inheritance. + + This is critical for understanding dataclass constructor signatures, as fields + from parent classes become required positional arguments in child classes. 
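+
+    A short sketch of the dataclass behavior that motivates this:
+
+        @dataclass
+        class Base:
+            model_name: str
+
+        @dataclass
+        class Child(Base):
+            max_tokens: int = 0
+
+        Child("gpt-4")  # model_name is inherited as a required positional argument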
+ """ + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with multi-level inheritance like skyvern's LLM models: + # GrandParent -> Parent -> Child + models_code = '''from dataclasses import dataclass, field +from typing import Optional, Literal + +@dataclass(frozen=True) +class GrandParentConfig: + """Base config with common fields.""" + model_name: str + required_env_vars: list[str] + +@dataclass(frozen=True) +class ParentConfig(GrandParentConfig): + """Intermediate config adding vision support.""" + supports_vision: bool + add_assistant_prefix: bool + +@dataclass(frozen=True) +class ChildConfig(ParentConfig): + """Full config with optional parameters.""" + litellm_params: Optional[dict] = field(default=None) + max_tokens: int | None = None + temperature: float | None = 0.7 + +@dataclass(frozen=True) +class RouterConfig(ParentConfig): + """Router config branching from ParentConfig.""" + model_list: list + main_model_group: str + routing_strategy: Literal["simple", "least-busy"] = "simple" +''' + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports only the child classes (not the base classes) + code = """from mypackage.models import ChildConfig, RouterConfig + +class ConfigRegistry: + def get_child_config(self) -> ChildConfig: + pass + + def get_router_config(self) -> RouterConfig: + pass +""" + code_path = package_dir / "registry.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + # Call get_imported_class_definitions + result = get_imported_class_definitions(context, tmp_path) + + # Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig + # (all classes needed to understand the full inheritance hierarchy) + assert len(result.code_strings) == 4, ( + f"Should extract 4 classes (GrandParent, Parent, Child, Router), got {len(result.code_strings)}" + ) + + # Combine extracted code + all_extracted_code = "\n".join(cs.code for cs in result.code_strings) + + # Verify all classes are extracted + assert "class GrandParentConfig" in all_extracted_code, "Should extract GrandParentConfig base class" + assert "class ParentConfig(GrandParentConfig)" in all_extracted_code, "Should extract ParentConfig" + assert "class ChildConfig(ParentConfig)" in all_extracted_code, "Should extract ChildConfig" + assert "class RouterConfig(ParentConfig)" in all_extracted_code, "Should extract RouterConfig" + + # Verify classes are ordered correctly (base classes before derived) + grandparent_idx = all_extracted_code.find("class GrandParentConfig") + parent_idx = all_extracted_code.find("class ParentConfig(") + child_idx = all_extracted_code.find("class ChildConfig(") + router_idx = all_extracted_code.find("class RouterConfig(") + + assert grandparent_idx < parent_idx, "GrandParentConfig should appear before ParentConfig" + assert parent_idx < child_idx, "ParentConfig should appear before ChildConfig" + assert parent_idx < router_idx, "ParentConfig should appear before RouterConfig" + + # Verify the critical fields are visible for constructor understanding + assert "model_name: str" in all_extracted_code, "Should include model_name field from GrandParent" + assert "required_env_vars: list[str]" in all_extracted_code, "Should include required_env_vars field" + assert "supports_vision: bool" in 
all_extracted_code, "Should include supports_vision field from Parent" + assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child" + assert "model_list: list" in all_extracted_code, "Should include model_list field from Router" + + +def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None: + """Extracts __init__ from collections.UserDict when a class inherits from it.""" + code = """from collections import UserDict + +class MyCustomDict(UserDict): + pass +""" + code_path = tmp_path / "mydict.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert len(result.code_strings) == 1 + code_string = result.code_strings[0] + + expected_code = """\ +class UserDict: + def __init__(self, dict=None, /, **kwargs): + self.data = {} + if dict is not None: + self.update(dict) + if kwargs: + self.update(kwargs) +""" + assert code_string.code == expected_code + assert code_string.file_path.as_posix().endswith("collections/__init__.py") + + +def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None: + """Returns empty when base class is from the project, not external.""" + child_code = """from base import ProjectBase + +class Child(ProjectBase): + pass +""" + child_path = tmp_path / "child.py" + child_path.write_text(child_code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: + """Returns empty for builtin classes like list that have no inspectable source.""" + code = """class MyList(list): + pass +""" + code_path = tmp_path / "mylist.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None: + """Extracts the same external base class only once even when inherited multiple times.""" + code = """from collections import UserDict + +class MyDict1(UserDict): + pass + +class MyDict2(UserDict): + pass +""" + code_path = tmp_path / "mydicts.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert len(result.code_strings) == 1 + expected_code = """\ +class UserDict: + def __init__(self, dict=None, /, **kwargs): + self.data = {} + if dict is not None: + self.update(dict) + if kwargs: + self.update(kwargs) +""" + assert result.code_strings[0].code == expected_code + + +def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None: + """Returns empty when there are no external base classes.""" + code = """class SimpleClass: + pass +""" + code_path = tmp_path / "simple.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="enum.StrEnum requires 
Python 3.11+") +def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> None: + """Tests that classes used as dependencies (enums, dataclasses) are kept in read-writable context. + + This test verifies that when a function uses classes like enums or dataclasses + as types or in match statements, those classes are included in the optimization + context, even though they don't contain any target functions. + """ + code = ''' +import dataclasses +import enum +import typing as t + + +class MessageKind(enum.StrEnum): + ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response" + BEGIN_EXFILTRATION = "begin-exfiltration" + + +@dataclasses.dataclass +class Message: + kind: str + + +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE + text: str = "" + + +@dataclasses.dataclass +class MessageInBeginExfiltration(Message): + kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION + + +MessageIn = ( + MessageInAskForClipboardResponse + | MessageInBeginExfiltration +) + + +def reify_channel_message(data: dict) -> MessageIn: + kind = data.get("kind", None) + + match kind: + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) + case MessageKind.BEGIN_EXFILTRATION: + return MessageInBeginExfiltration() + case _: + raise ValueError(f"Unknown message kind: '{kind}'") +''' + code_path = tmp_path / "message.py" + code_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="reify_channel_message", + file_path=code_path, + parents=[], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + expected_read_writable = """ +```python:message.py +import dataclasses +import enum +import typing as t + +class MessageKind(enum.StrEnum): + ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response" + BEGIN_EXFILTRATION = "begin-exfiltration" + + +@dataclasses.dataclass +class Message: + kind: str + + +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE + text: str = "" + + +@dataclasses.dataclass +class MessageInBeginExfiltration(Message): + kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION + + +MessageIn = ( + MessageInAskForClipboardResponse + | MessageInBeginExfiltration +) + + +def reify_channel_message(data: dict) -> MessageIn: + kind = data.get("kind", None) + + match kind: + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) + case MessageKind.BEGIN_EXFILTRATION: + return MessageInBeginExfiltration() + case _: + raise ValueError(f"Unknown message kind: '{kind}'") +``` +""" + assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip() + + +def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None: + """Test that external base class __init__ methods are included in testgen context. + + This covers line 65 in code_context_extractor.py where external_base_inits.code_strings + are appended to the testgen context when a class inherits from an external library. 
+ """ + code = """from collections import UserDict + +class MyCustomDict(UserDict): + def target_method(self): + return self.data +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyCustomDict", type="ClassDef")], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + # The testgen context should include the UserDict __init__ method + testgen_context = code_ctx.testgen_context.markdown + assert "class UserDict:" in testgen_context, "UserDict class should be in testgen context" + assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context" + assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included" + + +def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None: + """Test read-only code is completely removed when it exceeds token limit even without docstrings. + + This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set + to empty string when it still exceeds the token limit after docstring removal. + """ + # Create a second-degree helper with large implementation that has no docstrings + # Second-degree helpers go into read-only context + long_lines = [" x = 0"] + for i in range(150): + long_lines.append(f" x = x + {i}") + long_lines.append(" return x") + long_body = "\n".join(long_lines) + + code = f""" +class MyClass: + def __init__(self): + self.x = 1 + + def target_method(self): + return first_helper() + + +def first_helper(): + # First degree helper - calls second degree + return second_helper() + + +def second_helper(): + # Second degree helper - goes into read-only context +{long_body} +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + ) + + # Use a small optim_token_limit that allows read-writable but not read-only + # Read-writable is ~48 tokens, read-only is ~600 tokens + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + optim_token_limit=100, # Small limit to trigger read-only removal + ) + + # The read-only context should be empty because it exceeded the limit + assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit" + + +def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None: + """Test testgen context removes imported class definitions when exceeding token limit. 
+ + This covers lines 176-186 in code_context_extractor.py where: + - Testgen context exceeds limit (line 175) + - Removing docstrings still exceeds (line 175 again) + - Removing imported classes succeeds (line 177-183) + """ + # Create a package structure with a large type class used only in type annotations + # This ensures get_imported_class_definitions extracts the full class + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a large class with methods that will be extracted via get_imported_class_definitions + # Use methods WITHOUT docstrings so removing docstrings won't help much + many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)]) + type_class_code = f''' +class TypeClass: + """A type class for annotations.""" + + def __init__(self, value: int): + self.value = value + +{many_methods} +''' + type_class_path = package_dir / "types.py" + type_class_path.write_text(type_class_code, encoding="utf-8") + + # Main module uses TypeClass only in annotation (not instantiated) + # This triggers get_imported_class_definitions to extract the full class + main_code = """ +from mypackage.types import TypeClass + +def target_function(obj: TypeClass) -> int: + return obj.value +""" + main_path = package_dir / "main.py" + main_path.write_text(main_code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=main_path, + parents=[], + ) + + # Use a testgen_token_limit that: + # - Is exceeded by full context with imported class (~1500 tokens) + # - Is exceeded even after removing docstrings + # - But fits when imported class is removed (~40 tokens) + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + testgen_token_limit=200, # Small limit to trigger imported class removal + ) + + # The testgen context should exist (didn't raise ValueError) + testgen_context = code_ctx.testgen_context.markdown + assert testgen_context, "Testgen context should not be empty" + + # The target function should still be there + assert "def target_function" in testgen_context, "Target function should be in testgen context" + + # The large imported class should NOT be included (removed due to token limit) + assert "class TypeClass" not in testgen_context, ( + "TypeClass should be removed from testgen context when exceeding token limit" + ) + + +def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None: + """Test that ValueError is raised when testgen context exceeds limit even after all fallbacks. + + This covers line 186 in code_context_extractor.py. 
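+
+    A hypothetical caller would treat this as "function too large to process":
+
+        try:
+            get_code_optimization_context(fto, root, testgen_token_limit=50)
+        except ValueError:
+            pass  # skip this function rather than send an oversized context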
+ """ + # Create a function with a very long body that exceeds limits even without imports/docstrings + long_lines = [" x = 0"] + for i in range(200): + long_lines.append(f" x = x + {i}") + long_lines.append(" return x") + long_body = "\n".join(long_lines) + + code = f""" +def target_function(): +{long_body} +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=file_path, + parents=[], + ) + + # Use a very small testgen_token_limit that cannot fit even the base function + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"): + get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + testgen_token_limit=50, # Very small limit + ) + + +def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None: + """Test handling of base class accessed as module.ClassName (ast.Attribute). + + This covers line 616 in code_context_extractor.py. + """ + # Use the standard import style which the code actually handles + code = """from collections import UserDict + +class MyDict(UserDict): + def custom_method(self): + return self.data +""" + code_path = tmp_path / "mydict.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + # Should extract UserDict __init__ + assert len(result.code_strings) == 1 + assert "class UserDict:" in result.code_strings[0].code + assert "def __init__" in result.code_strings[0].code + + +def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None: + """Test handling when base class has no __init__ method. + + This covers line 641 in code_context_extractor.py. + """ + # Create a class inheriting from a class that doesn't have inspectable __init__ + code = """from typing import Protocol + +class MyProtocol(Protocol): + pass +""" + code_path = tmp_path / "myproto.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + # Protocol's __init__ can't be easily inspected, should handle gracefully + # Result may be empty or contain Protocol based on implementation + assert isinstance(result.code_strings, list) + + +def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None: + """Test collect_names_from_annotation handles ast.Attribute annotations. + + This covers line 756 in code_context_extractor.py. + """ + # Use __import__ to avoid polluting the test file's detected imports + ast_mod = __import__("ast") + + # Parse code with type annotation using attribute access + code = "x: typing.List[int] = []" + tree = ast_mod.parse(code) + names: set[str] = set() + + # Find the annotation node + for node in ast_mod.walk(tree): + if isinstance(node, ast_mod.AnnAssign) and node.annotation: + collect_names_from_annotation(node.annotation, names) + break + + assert "typing" in names + + +def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None: + """Test extract_imports_for_class handles decorator calls with attribute access. + + This covers lines 707-708 in code_context_extractor.py. 
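+
+    The decorator shape under test is a call through an attribute, i.e.
+    @functools.lru_cache(maxsize=128) parses as
+    ast.Call(func=ast.Attribute(value=ast.Name(id="functools"), attr="lru_cache")).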
+ """ + ast_mod = __import__("ast") + + code = """ +import functools + +@functools.lru_cache(maxsize=128) +class CachedClass: + pass +""" + tree = ast_mod.parse(code) + + # Find the class node + class_node = None + for node in ast_mod.walk(tree): + if isinstance(node, ast_mod.ClassDef): + class_node = node + break + + assert class_node is not None + result = extract_imports_for_class(tree, class_node, code) + + # Should include the functools import + assert "functools" in result + + +def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None: + """Test that annotated assignments used by target function are in read-writable context. + + This covers lines 965-969 in code_context_extractor.py. + """ + code = """ +CONFIG_VALUE: int = 42 + +class MyClass: + def __init__(self): + self.x = CONFIG_VALUE + + def target_method(self): + return self.x +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + # CONFIG_VALUE should be in read-writable context since it's used by __init__ + read_writable = code_ctx.read_writable_code.markdown + assert "CONFIG_VALUE" in read_writable + + +def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None: + """Test handling when module_path is None in get_imported_class_definitions. + + This covers line 560 in code_context_extractor.py. + """ + # Create code that imports from a non-existent or unresolvable module + code = """ +from nonexistent_module_xyz import SomeClass + +class MyClass: + def method(self, obj: SomeClass): + pass +""" + code_path = tmp_path / "test.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_imported_class_definitions(context, tmp_path) + + # Should handle gracefully and return empty or partial results + assert isinstance(result.code_strings, list) + + +def test_get_imported_names_import_star(tmp_path: Path) -> None: + """Test get_imported_names handles import * correctly. + + This covers lines 808-809 and 824-825 in code_context_extractor.py. + """ + import libcst as cst + + # Test regular import * + # Note: "import *" is not valid Python, but "from x import *" is + from_import_star = cst.parse_statement("from os import *") + assert isinstance(from_import_star, cst.SimpleStatementLine) + import_node = from_import_star.body[0] + assert isinstance(import_node, cst.ImportFrom) + + from codeflash.context.code_context_extractor import get_imported_names + + result = get_imported_names(import_node) + assert result == {"*"} + + +def test_get_imported_names_aliased_import(tmp_path: Path) -> None: + """Test get_imported_names handles aliased imports correctly. + + This covers lines 812-813 and 828-829 in code_context_extractor.py. 
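+
+    Aliased imports must map to the bound name, not the source name:
+    "import numpy as np" binds "np", and "from os import path as ospath"
+    binds "ospath".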
+    """
+    import libcst as cst
+
+    from codeflash.context.code_context_extractor import get_imported_names
+
+    # Test import with alias
+    import_stmt = cst.parse_statement("import numpy as np")
+    assert isinstance(import_stmt, cst.SimpleStatementLine)
+    import_node = import_stmt.body[0]
+    assert isinstance(import_node, cst.Import)
+
+    result = get_imported_names(import_node)
+    assert "np" in result
+
+    # Test from-import with alias
+    from_import_stmt = cst.parse_statement("from os import path as ospath")
+    assert isinstance(from_import_stmt, cst.SimpleStatementLine)
+    from_import_node = from_import_stmt.body[0]
+    assert isinstance(from_import_node, cst.ImportFrom)
+
+    result2 = get_imported_names(from_import_node)
+    assert "ospath" in result2
+
+
+def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
+    """Test get_imported_names handles dotted imports correctly.
+
+    This covers lines 816-822 in code_context_extractor.py.
+    """
+    import libcst as cst
+
+    from codeflash.context.code_context_extractor import get_imported_names
+
+    # A dotted import like "import os.path" binds only the top-level name "os"
+    import_stmt = cst.parse_statement("import os.path")
+    assert isinstance(import_stmt, cst.SimpleStatementLine)
+    import_node = import_stmt.body[0]
+    assert isinstance(import_node, cst.Import)
+
+    result = get_imported_names(import_node)
+    assert "os" in result
+
+
+def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
+    """Test UsedNameCollector handles various node types.
+
+    This covers lines 767-801 in code_context_extractor.py.
+    """
+    import libcst as cst
+
+    from codeflash.context.code_context_extractor import UsedNameCollector
+
+    code = """
+import os
+from typing import List
+
+x: int = 1
+y = os.path.join("a", "b")
+
+class MyClass:
+    z = 10
+
+def my_func():
+    pass
+"""
+    module = cst.parse_module(code)
+    collector = UsedNameCollector()
+    # Visit through MetadataWrapper so any metadata the collector depends on is resolved
+    cst.MetadataWrapper(module).visit(collector)
+
+    # Check used names
+    assert "os" in collector.used_names
+    assert "int" in collector.used_names
+    assert "List" in collector.used_names
+
+    # Check defined names
+    assert "x" in collector.defined_names
+    assert "y" in collector.defined_names
+    assert "MyClass" in collector.defined_names
+    assert "my_func" in collector.defined_names
+
+    # Check external names (used but not defined)
+    external = collector.get_external_names()
+    assert "os" in external
+    assert "x" not in external  # x is defined
+
+
+def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
+    """Test that imported classes with bases in the same module are extracted correctly.
+
+    This covers line 528 in code_context_extractor.py (early return for already-extracted classes).
+    """
+    package_dir = tmp_path / "mypackage"
+    package_dir.mkdir()
+    (package_dir / "__init__.py").write_text("", encoding="utf-8")
+
+    # Create a module with an inheritance chain
+    module_code = """
+class BaseClass:
+    def __init__(self):
+        self.base = True
+
+class MiddleClass(BaseClass):
+    def __init__(self):
+        super().__init__()
+        self.middle = True
+
+class DerivedClass(MiddleClass):
+    def __init__(self):
+        super().__init__()
+        self.derived = True
+"""
+    module_path = package_dir / "classes.py"
+    module_path.write_text(module_code, encoding="utf-8")
+
+    # Main module imports and uses the derived class
+    main_code = """
+from mypackage.classes import DerivedClass
+
+def target_function(obj: DerivedClass) -> bool:
+    return obj.derived
+"""
+    main_path = package_dir / "main.py"
+    main_path.write_text(main_code, encoding="utf-8")
+
+    context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
+    result = get_imported_class_definitions(context, tmp_path)
+
+    # At least part of the inheritance chain should be extracted
+    all_code = "\n".join(cs.code for cs in result.code_strings)
+    assert "class BaseClass" in all_code or "class DerivedClass" in all_code
+
+
+def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
+    """Test get_imported_names handles from-imports without aliases.
+
+    This covers lines 830-831 in code_context_extractor.py.
+    """
+    import libcst as cst
+
+    from codeflash.context.code_context_extractor import get_imported_names
+
+    # Test from-import without alias
+    from_import_stmt = cst.parse_statement("from os import path, getcwd")
+    assert isinstance(from_import_stmt, cst.SimpleStatementLine)
+    from_import_node = from_import_stmt.body[0]
+    assert isinstance(from_import_node, cst.ImportFrom)
+
+    result = get_imported_names(from_import_node)
+    assert "path" in result
+    assert "getcwd" in result
+
+
+def test_get_imported_names_regular_import(tmp_path: Path) -> None:
+    """Test get_imported_names handles regular imports.
+
+    This covers lines 814-815 in code_context_extractor.py.
+    """
+    import libcst as cst
+
+    from codeflash.context.code_context_extractor import get_imported_names
+
+    # Test a regular import without an alias
+    import_stmt = cst.parse_statement("import json")
+    assert isinstance(import_stmt, cst.SimpleStatementLine)
+    import_node = import_stmt.body[0]
+    assert isinstance(import_node, cst.Import)
+
+    result = get_imported_names(import_node)
+    assert "json" in result
+
+
+def test_augmented_assignment_global_included_when_used(tmp_path: Path) -> None:
+    """Test that a global mutated via augmented assignment is included when the class uses it.
+
+    This covers lines 962-969 in code_context_extractor.py.
+    """
+    code = """
+counter = 0
+
+class MyClass:
+    def __init__(self):
+        global counter
+        counter += 1
+
+    def target_method(self):
+        return 42
+"""
+    file_path = tmp_path / "test_code.py"
+    file_path.write_text(code, encoding="utf-8")
+
+    func_to_optimize = FunctionToOptimize(
+        function_name="target_method",
+        file_path=file_path,
+        parents=[FunctionParent(name="MyClass", type="ClassDef")],
+    )
+
+    code_ctx = get_code_optimization_context(
+        function_to_optimize=func_to_optimize,
+        project_root_path=tmp_path,
+    )
+
+    # counter should be in context since __init__ uses it
+    read_writable = code_ctx.read_writable_code.markdown
+    assert "counter" in read_writable
diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py
index 04d83f13f..da83146a8 100644
--- a/tests/test_code_replacement.py
+++ b/tests/test_code_replacement.py
@@ -2119,7 +2119,6 @@ def new_function2(value):
     expected_code = """import numpy as np
 a = 6
-
 if 2<3:
     a=4
 else:
diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py
index 6d76e2bf6..b9112f047 100644
--- a/tests/test_codeflash_capture.py
+++ b/tests/test_codeflash_capture.py
@@ -1602,7 +1602,94 @@ def calculate_portfolio_metrics(
         # now the test should match and no diffs should be found
         assert len(diffs) == 0
         assert matched
-
+
+    finally:
+        test_path.unlink(missing_ok=True)
+        fto_file_path.unlink(missing_ok=True)
+
+
+def test_codeflash_capture_with_slots_class() -> None:
+    """Test that codeflash_capture works with classes that use __slots__ instead of __dict__."""
+    test_code = """
+from code_to_optimize.tests.pytest.sample_code import SlotsClass
+import unittest
+
+def test_slots_class():
+    obj = SlotsClass(10, "test")
+    assert obj.x == 10
+    assert obj.y == "test"
+"""
+    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
+    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
+    sample_code = f"""
+from codeflash.verification.codeflash_capture import codeflash_capture
+
+class SlotsClass:
+    __slots__ = ('x', 'y')
+
+    @codeflash_capture(function_name="SlotsClass.__init__", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+"""
+    test_file_name = "test_slots_class_temp.py"
+    test_path = test_dir / test_file_name
+    test_path_perf = test_dir / "test_slots_class_temp_perf.py"
+
+    tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
+    project_root_path = (Path(__file__).parent / "..").resolve()
+    sample_code_path = test_dir / "sample_code.py"
+
+    try:
+        with test_path.open("w") as f:
+            f.write(test_code)
+        with sample_code_path.open("w") as f:
+            f.write(sample_code)
+
+        test_env = os.environ.copy()
+        test_env["CODEFLASH_TEST_ITERATION"] = "0"
+        test_env["CODEFLASH_LOOP_INDEX"] = "1"
+        test_type = TestType.EXISTING_UNIT_TEST
+        test_config = TestConfig(
+            tests_root=tests_root,
+            tests_project_rootdir=project_root_path,
+            project_root_path=project_root_path,
+            test_framework="pytest",
+            pytest_cmd="pytest",
+        )
+        fto = FunctionToOptimize(
+            function_name="__init__",
+            file_path=sample_code_path,
+            parents=[FunctionParent(name="SlotsClass", type="ClassDef")],
+        )
+        func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
+        func_optimizer.test_files = TestFiles(
+            test_files=[
+                TestFile(
+                    instrumented_behavior_file_path=test_path,
+                    test_type=test_type,
+                    original_file_path=test_path,
+                    benchmarking_file_path=test_path_perf,
+                )
+            ]
+        )
+        test_results, coverage_data = func_optimizer.run_and_parse_tests(
+            testing_type=TestingMode.BEHAVIOR,
+            test_env=test_env,
+            test_files=func_optimizer.test_files,
+            optimization_iteration=0,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
+            testing_time=0.1,
+        )
+
+        # Test should pass and capture the slots values
+        assert len(test_results) == 1
+        assert test_results[0].did_pass
+        # The return value should contain the slot values
+        assert test_results[0].return_value[0]["x"] == 10
+        assert test_results[0].return_value[0]["y"] == "test"
+
     finally:
         test_path.unlink(missing_ok=True)
-        fto_file_path.unlink(missing_ok=True)
\ No newline at end of file
+        sample_code_path.unlink(missing_ok=True)
\ No newline at end of file
diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py
index d1eeb6e99..952479d3a 100644
--- a/tests/test_get_read_writable_code.py
+++ b/tests/test_get_read_writable_code.py
@@ -218,8 +218,28 @@ class Inner:
         def target(self):
             pass
     """
+    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
+    expected = dedent("""
+    class MyClass:
+        def method(self):
+            pass
+
+        class Inner:
+            def target(self):
+                pass
+    """)
+    assert result.strip() == expected.strip()
+
+
+def test_no_targets_found_raises_for_nonexistent() -> None:
+    """Test that ValueError is raised when the target function doesn't exist at all."""
+    code = """
+    class MyClass:
+        def method(self):
+            pass
+    """
     with pytest.raises(ValueError, match="No target functions found in the provided code"):
-        parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
+        parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})
 
 
 def test_module_var() -> None:
diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py
index e33e98d24..2d1f22509 100644
--- a/tests/test_multi_file_code_replacement.py
+++ b/tests/test_multi_file_code_replacement.py
@@ -124,7 +124,7 @@ def _get_string_usage(text: str) -> Usage:
     helper_file.unlink(missing_ok=True)
     main_file.unlink(missing_ok=True)
 
-    
+
 expected_helper = """import re
 from collections.abc import Sequence
diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py
index 8d09a95e1..edf11e7c5 100644
--- a/tests/test_remove_unused_definitions.py
+++ b/tests/test_remove_unused_definitions.py
@@ -481,3 +481,86 @@ def unused_function():
     qualified_functions = {"get_platform_info", "get_loop_result"}
     result = remove_unused_definitions_by_function_names(code, qualified_functions)
     assert result.strip() == expected.strip()
+
+
+def test_enum_attribute_access_dependency() -> None:
+    """Test that enum/class attribute access like MessageKind.VALUE is tracked as a dependency."""
+    code = """
+from enum import Enum
+
+class MessageKind(Enum):
+    VALUE = "value"
+    OTHER = "other"
+
+class UnusedEnum(Enum):
+    UNUSED = "unused"
+
+UNUSED_VAR = 123
+
+def process_message(kind):
+    match kind:
+        case MessageKind.VALUE:
+            return "got value"
+        case MessageKind.OTHER:
+            return "got other"
+    return "unknown"
+"""
+
+    expected = """
+from enum import Enum
+
+class MessageKind(Enum):
+    VALUE = "value"
+    OTHER = "other"
+
+class UnusedEnum(Enum):
+    UNUSED = "unused"
+
+def process_message(kind):
+    match kind:
+        case MessageKind.VALUE:
+            return "got value"
+        case MessageKind.OTHER:
+            return "got other"
+    return "unknown"
+"""
+
+    qualified_functions = {"process_message"}
+    result = remove_unused_definitions_by_function_names(code, qualified_functions)
+    # MessageKind should be preserved because process_message uses MessageKind.VALUE
+    assert "class MessageKind" in result
+    # UNUSED_VAR should be removed
+    assert "UNUSED_VAR" not in result
+    assert result.strip() == expected.strip()
+
+
+def test_attribute_access_does_not_track_attr_name() -> None:
+    """Test that self.x attribute access doesn't track 'x' as a dependency on module-level x."""
+    code = """
+x = "module_level_x"
+UNUSED_VAR = "unused"
+
+class MyClass:
+    def __init__(self):
+        self.x = 1  # This 'x' is an attribute, not a reference to module-level 'x'
+
+    def get_x(self):
+        return self.x  # This 'x' is also an attribute access
+"""
+
+    expected = """
+class MyClass:
+    def __init__(self):
+        self.x = 1  # This 'x' is an attribute, not a reference to module-level 'x'
+
+    def get_x(self):
+        return self.x  # This 'x' is also an attribute access
+"""
+
+    qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
+    result = remove_unused_definitions_by_function_names(code, qualified_functions)
+    # Module-level x should NOT be kept (self.x doesn't reference it)
+    assert 'x = "module_level_x"' not in result
+    # UNUSED_VAR should also be removed
+    assert "UNUSED_VAR" not in result
+    assert result.strip() == expected.strip()
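A note on the rule the last two tests pin down: in Python's AST, only the value side of an ast.Attribute node can refer to a module-level name, while the attr field is a plain string. A minimal standalone sketch of that traversal rule (illustrative only, not codeflash's implementation):

import ast

def referenced_names(source: str) -> set[str]:
    """Collect names that could refer to module-level definitions."""
    names: set[str] = set()

    class Visitor(ast.NodeVisitor):
        def visit_Name(self, node: ast.Name) -> None:
            names.add(node.id)

        def visit_Attribute(self, node: ast.Attribute) -> None:
            # Recurse only into the value side ("self", "MessageKind");
            # node.attr itself is never recorded as a referenced name.
            self.visit(node.value)

    Visitor().visit(ast.parse(source))
    return names

# "self.x" contributes only "self"; "MessageKind.VALUE" only "MessageKind".
assert referenced_names("self.x") == {"self"}
assert referenced_names("MessageKind.VALUE") == {"MessageKind"}

Under this rule, self.x can never keep a module-level x alive, while MessageKind.VALUE keeps MessageKind alive, which is exactly what the two tests assert.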
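The get_imported_names tests earlier rely on a related binding rule: "import os.path" binds only the top-level name "os" in the importing module, and an alias rebinds to the alias name. A standalone libcst sketch of that rule (again illustrative, not codeflash's get_imported_names):

import libcst as cst

def bound_name(alias: cst.ImportAlias) -> str:
    """Return the local name an import alias binds in the importing module."""
    if alias.asname is not None:
        # "import numpy as np" binds "np"
        asname = alias.asname.name
        assert isinstance(asname, cst.Name)
        return asname.value
    # "import os.path" binds only the top-level package name "os",
    # so walk the Attribute chain down to its leftmost Name.
    node = alias.name
    while isinstance(node, cst.Attribute):
        node = node.value
    assert isinstance(node, cst.Name)
    return node.value

stmt = cst.parse_statement("import os.path")
assert isinstance(stmt, cst.SimpleStatementLine)
import_node = stmt.body[0]
assert isinstance(import_node, cst.Import)
assert bound_name(import_node.names[0]) == "os"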