diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 4ef96b308..be8321e7a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -193,6 +193,75 @@ def optimize_python_code( # noqa: D417 console.rule() return [] + def augmented_optimize( # noqa: D417 + self, + source_code: str, + system_prompt: str, + user_prompt: str, + trace_id: str, + dependency_code: str | None = None, + n_candidates: int = 3, + ) -> list[OptimizedCandidate]: + """Optimize code with custom prompts via /ai/augmented-optimize endpoint. + + Parameters + ---------- + - source_code (str): The python code to optimize (markdown format). + - system_prompt (str): Custom system prompt for the LLM. + - user_prompt (str): Custom user prompt for the LLM. + - trace_id (str): Trace id of optimization run. + - dependency_code (str | None): Optional dependency code context. + - n_candidates (int): Number of candidates to generate (max 3). + + Returns + ------- + - list[OptimizedCandidate]: A list of Optimization Candidates. + + """ + logger.info("Generating augmented optimization candidates…") + console.rule() + start_time = time.perf_counter() + git_repo_owner, git_repo_name = safe_get_repo_owner_and_name() + + payload = { + "source_code": source_code, + "dependency_code": dependency_code, + "trace_id": trace_id, + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "n_candidates": min(n_candidates, 3), + "python_version": platform.python_version(), + "codeflash_version": codeflash_version, + "current_username": get_last_commit_author_if_pr_exists(None), + "repo_owner": git_repo_owner, + "repo_name": git_repo_name, + } + logger.debug(f"Sending augmented optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") + + try: + response = self.make_ai_service_request("/augmented-optimize", payload=payload, timeout=self.timeout) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating augmented optimization candidates: {e}") + ph("cli-augmented-optimize-error-caught", {"error": str(e)}) + console.rule() + return [] + + if response.status_code == 200: + optimizations_json = response.json()["optimizations"] + end_time = time.perf_counter() + logger.debug(f"!lsp|Generating augmented optimizations took {end_time - start_time:.2f} seconds.") + logger.info(f"!lsp|Received {len(optimizations_json)} augmented optimization candidates.") + console.rule() + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.AUGMENTED) + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating augmented optimization candidates: {response.status_code} - {error}") + ph("cli-augmented-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def get_jit_rewritten_code( # noqa: D417 self, source_code: str, trace_id: str ) -> list[OptimizedCandidate]: diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index df0ae57ce..4da82bc6e 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -6,6 +6,7 @@ from codeflash.cli_cmds import logging_config from codeflash.cli_cmds.cli_common import apologize_and_exit +from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.extension import install_vscode_extension @@ -57,6 +58,22 @@ def parse_args() -> Namespace: help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.", ) + # create-pr subcommand for creating PRs from augmented optimization results + create_pr_parser = subparsers.add_parser("create-pr", help="Create a PR from previously applied optimizations") + create_pr_parser.set_defaults(func=cmd_create_pr) + create_pr_parser.add_argument( + "--results-file", + type=str, + default="codeflash_phase1_results.json", + help="Path to augmented output JSON file (default: codeflash_phase1_results.json)", + ) + create_pr_parser.add_argument( + "--function", type=str, help="Function name (required if multiple functions in results)" + ) + create_pr_parser.add_argument( + "--git-remote", type=str, default="origin", help="Git remote to use for PR creation (default: origin)" + ) + parser.add_argument("--file", help="Try to optimize only this file") parser.add_argument("--function", help="Try to optimize only this function within the given file path") parser.add_argument( @@ -117,6 +134,22 @@ def parse_args() -> Namespace: parser.add_argument( "--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium" ) + parser.add_argument( + "--augmented", + action="store_true", + help="Enable augmented optimization mode for two-phase optimization workflow", + ) + parser.add_argument( + "--augmented-prompt-file", + type=str, + help="Path to YAML file with custom system_prompt and user_prompt for Phase 2", + ) + parser.add_argument( + "--augmented-output", + type=str, + default="codeflash_phase1_results.json", + help="Path to write Phase 1 results JSON (default: codeflash_phase1_results.json)", + ) args, unknown_args = parser.parse_known_args() sys.argv[:] = [sys.argv[0], *unknown_args] @@ -175,6 +208,15 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: "Async function optimization is now enabled by default." ) + if args.augmented_prompt_file and not args.augmented: + exit_with_message("--augmented-prompt-file requires --augmented flag", error_on_exit=True) + + if args.augmented_prompt_file: + prompt_file = Path(args.augmented_prompt_file) + if not prompt_file.exists(): + exit_with_message(f"Augmented prompt file {args.augmented_prompt_file} does not exist", error_on_exit=True) + args.augmented_prompt_file = prompt_file.resolve() + return args diff --git a/codeflash/cli_cmds/cmd_create_pr.py b/codeflash/cli_cmds/cmd_create_pr.py new file mode 100644 index 000000000..902c6ff8e --- /dev/null +++ b/codeflash/cli_cmds/cmd_create_pr.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.code_utils import exit_with_message +from codeflash.code_utils.git_utils import git_root_dir +from codeflash.models.models import Phase1Output, TestResults +from codeflash.result.create_pr import check_create_pr +from codeflash.result.explanation import Explanation + +if TYPE_CHECKING: + from argparse import Namespace + +# Pattern to extract file path from markdown code block: ```python:path/to/file.py +MARKDOWN_FILE_PATH_PATTERN = re.compile(r"```python:([^\n]+)") + + +def extract_file_path_from_markdown(markdown_code: str) -> str | None: + """Extract the file path from markdown code block format. + + Format: ```python:path/to/file.py + """ + match = MARKDOWN_FILE_PATH_PATTERN.search(markdown_code) + if match: + return match.group(1).strip() + return None + + +def extract_code_from_markdown(markdown_code: str) -> str: + r"""Extract the code content from markdown code block. + + Removes the ```python:path\n ... ``` wrapper. + """ + # Remove opening markdown fence with optional path + code = re.sub(r"^```python(?::[^\n]*)?\n", "", markdown_code) + # Remove closing fence + return re.sub(r"\n```$", "", code) + + +def create_pr(args: Namespace) -> None: + """Create a PR from previously applied optimizations.""" + results_file = Path(args.results_file) + + if not results_file.exists(): + exit_with_message(f"Results file not found: {results_file}", error_on_exit=True) + + # Load and parse results + with results_file.open(encoding="utf-8") as f: + data = json.load(f) + + try: + output = Phase1Output.model_validate(data) + except Exception as e: + exit_with_message(f"Failed to parse results file: {e}", error_on_exit=True) + + # Find the function result + if len(output.functions) == 0: + exit_with_message("No functions in results file", error_on_exit=True) + + if len(output.functions) > 1 and not args.function: + func_names = [f.function_name for f in output.functions] + exit_with_message( + f"Multiple functions in results. Specify one with --function: {func_names}", error_on_exit=True + ) + + func_result = output.functions[0] + if args.function: + func_result = next((f for f in output.functions if f.function_name == args.function), None) + if not func_result: + exit_with_message(f"Function {args.function} not found in results", error_on_exit=True) + assert func_result is not None # for type checker - exit_with_message doesn't return + + if not func_result.best_candidate_id: + exit_with_message("No successful optimization found in results", error_on_exit=True) + + # Get file path - prefer explicit field, fall back to extracting from markdown + file_path_str = func_result.file_path + if not file_path_str: + file_path_str = extract_file_path_from_markdown(func_result.original_source_code) + + if not file_path_str: + exit_with_message( + "Could not determine file path from results. Results file may be from an older version of codeflash.", + error_on_exit=True, + ) + assert file_path_str is not None # for type checker - exit_with_message doesn't return + + file_path = Path(file_path_str) + if not file_path.exists(): + exit_with_message(f"Source file not found: {file_path}", error_on_exit=True) + + # Read current (optimized) file content + current_content = file_path.read_text(encoding="utf-8") + + # Extract original code (strip markdown) + original_code = extract_code_from_markdown(func_result.original_source_code) + + # Get the best candidate's explanation + best_explanation = func_result.best_candidate_explanation + if not best_explanation: + # Fall back to the candidate's explanation if the final explanation wasn't captured + best_candidate = next( + (c for c in func_result.candidates if c.optimization_id == func_result.best_candidate_id), None + ) + best_explanation = best_candidate.explanation if best_candidate else "Optimization applied" + + # Build Explanation object for PR creation + explanation = Explanation( + raw_explanation_message=best_explanation, + winning_behavior_test_results=TestResults(), + winning_benchmarking_test_results=TestResults(), + original_runtime_ns=func_result.original_runtime_ns or 0, + best_runtime_ns=func_result.best_runtime_ns or func_result.original_runtime_ns or 0, + function_name=func_result.function_name, + file_path=file_path, + ) + + logger.info(f"Creating PR for optimized function: {func_result.function_name}") + logger.info(f"File: {file_path}") + if func_result.best_speedup_ratio: + logger.info(f"Speedup: {func_result.best_speedup_ratio * 100:.1f}%") + + # Call existing PR creation + check_create_pr( + original_code={file_path: original_code}, + new_code={file_path: current_content}, + explanation=explanation, + existing_tests_source=func_result.existing_tests_source or "", + generated_original_test_source="", + function_trace_id=func_result.trace_id, + coverage_message="", + replay_tests=func_result.replay_tests_source or "", + concolic_tests=func_result.concolic_tests_source or "", + optimization_review="", + root_dir=git_root_dir(), + git_remote=getattr(args, "git_remote", None), + precomputed_test_report=func_result.test_report, + precomputed_loop_count=func_result.loop_count, + ) + + # Cleanup results file after successful PR creation + results_file.unlink() + logger.info(f"Cleaned up results file: {results_file}") diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 6ddfe763a..5f296a8a3 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,6 +2,7 @@ import ast import time +from collections import defaultdict from dataclasses import dataclass from importlib.util import find_spec from itertools import chain @@ -25,6 +26,157 @@ from codeflash.models.models import FunctionSource +class NameCollector(cst.CSTVisitor): + """Collects all Name nodes referenced in a CST expression. + + Used to find what names an assignment depends on (e.g., function calls in the RHS). + Properly scopes comprehension variables and lambda parameters so they are not + treated as external dependencies. + """ + + def __init__(self) -> None: + super().__init__() + self.names: set[str] = set() + # Stack of sets tracking locally bound names (comprehension vars, lambda params) + self._local_names_stack: list[set[str]] = [] + + def _is_local_name(self, name: str) -> bool: + """Check if a name is bound in any enclosing local scope.""" + return any(name in scope for scope in self._local_names_stack) + + def visit_Name(self, node: cst.Name) -> None: + if not self._is_local_name(node.value): + self.names.add(node.value) + + # --- Comprehension handling --- + def visit_ListComp(self, node: cst.ListComp) -> bool: + """Handle list comprehensions: [expr for x in iterable if cond].""" + self._visit_comprehension(node.for_in, node.elt) + return False # Don't visit children - we handle them manually + + def visit_SetComp(self, node: cst.SetComp) -> bool: + """Handle set comprehensions: {expr for x in iterable if cond}.""" + self._visit_comprehension(node.for_in, node.elt) + return False + + def visit_DictComp(self, node: cst.DictComp) -> bool: + """Handle dict comprehensions: {key: value for x in iterable if cond}.""" + # For dict comps, we need to visit both key and value + self._visit_dict_comprehension(node.for_in, node.key, node.value) + return False + + def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool: + """Handle generator expressions: (expr for x in iterable if cond).""" + self._visit_comprehension(node.for_in, node.elt) + return False + + def _extract_comp_target_names(self, target: cst.BaseExpression) -> set[str]: + """Extract names bound by a comprehension target (e.g., x in 'for x in items').""" + return _extract_names_from_target(target) + + def _visit_comprehension(self, for_in: cst.CompFor, elt: cst.BaseExpression) -> None: + """Visit a comprehension, properly scoping the loop variables.""" + # Create a new scope for this comprehension + local_scope: set[str] = set() + self._local_names_stack.append(local_scope) + + # Process all for..in clauses, collecting bound names as we go + self._visit_comp_for(for_in, local_scope) + + # Visit the element expression (all bound names are now in scope) + elt.visit(self) + + # Pop the scope + self._local_names_stack.pop() + + def _visit_dict_comprehension( + self, for_in: cst.CompFor, key: cst.BaseExpression, value: cst.BaseExpression + ) -> None: + """Visit a dict comprehension, properly scoping the loop variables.""" + local_scope: set[str] = set() + self._local_names_stack.append(local_scope) + + self._visit_comp_for(for_in, local_scope) + + # Visit both key and value expressions + key.visit(self) + value.visit(self) + + self._local_names_stack.pop() + + def _visit_comp_for(self, comp_for: cst.CompFor, local_scope: set[str]) -> None: + """Recursively visit CompFor nodes, adding bound names to scope.""" + # First, visit the iterable (before adding target names to scope!) + # This is important: `x` in `[x for x in x]` - the RHS `x` is external + comp_for.iter.visit(self) + + # Add the target names to the local scope + target_names = self._extract_comp_target_names(comp_for.target) + local_scope.update(target_names) + + # Visit any conditions (ifs) + for if_clause in comp_for.ifs: + if_clause.test.visit(self) + + # Visit nested for..in if present + if comp_for.inner_for_in is not None: + self._visit_comp_for(comp_for.inner_for_in, local_scope) + + # --- Lambda handling --- + def visit_Lambda(self, node: cst.Lambda) -> bool: + """Handle lambda expressions: lambda x, y: x + y + z.""" + # Collect parameter names + param_names = self._extract_lambda_param_names(node.params) + + # Push a new scope with these parameters + self._local_names_stack.append(param_names) + + # Visit the lambda body + node.body.visit(self) + + # Pop the scope + self._local_names_stack.pop() + + return False # Don't visit children - we handle them manually + + def _extract_lambda_param_names(self, params: cst.Parameters) -> set[str]: + """Extract all parameter names from a lambda's parameters.""" + names: set[str] = set() + + # Regular positional/keyword params + for param in params.params: + names.add(param.name.value) + + # *args + if params.star_arg and isinstance(params.star_arg, cst.Param): + names.add(params.star_arg.name.value) + + # Keyword-only params + for param in params.kwonly_params: + names.add(param.name.value) + + # **kwargs + if params.star_kwarg: + names.add(params.star_kwarg.name.value) + + # Positional-only params (Python 3.8+) + for param in params.posonly_params: + names.add(param.name.value) + + return names + + +def get_names_in_expression(node: cst.BaseExpression) -> set[str]: + """Extract all names referenced in a CST expression. + + Comprehension variables and lambda parameters are properly scoped and + are NOT included as external dependencies. + """ + collector = NameCollector() + node.visit(collector) + return collector.names + + class GlobalFunctionCollector(cst.CSTVisitor): """Collects all module-level function definitions (not inside classes or other functions).""" @@ -131,7 +283,13 @@ def _collect(n: cst.CSTNode) -> None: class GlobalAssignmentCollector(cst.CSTVisitor): - """Collects all global assignment statements.""" + """Collects all global assignment statements. + + Only collects simple assignments at module level, NOT inside: + - Functions/classes (scope_depth) + - If/else blocks (if_else_depth) + - For/while loops, with statements, try blocks (compound_depth) + """ def __init__(self) -> None: super().__init__() @@ -140,6 +298,9 @@ def __init__(self) -> None: # Track scope depth to identify global assignments self.scope_depth = 0 self.if_else_depth = 0 + # Track other compound statements (for, while, with, try) where assignments + # inside them depend on loop variables or context managers + self.compound_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 @@ -166,9 +327,44 @@ def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002 # Else blocks are already counted as part of the if statement return True + def visit_For(self, node: cst.For) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_For(self, original_node: cst.For) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_While(self, node: cst.While) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_While(self, original_node: cst.While) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_With(self, node: cst.With) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_With(self, original_node: cst.With) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Try(self, node: cst.Try) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Match(self, node: cst.Match) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_Match(self, original_node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_Assign(self, node: cst.Assign) -> Optional[bool]: - # Only process global assignments (not inside functions, classes, etc.) - if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level + # Only process global assignments (not inside functions, classes, loops, etc.) + if self.scope_depth == 0 and self.if_else_depth == 0 and self.compound_depth == 0: for target in node.targets: if isinstance(target.target, cst.Name): name = target.target.value @@ -221,7 +417,13 @@ def find_insertion_index_after_imports(node: cst.Module) -> int: class GlobalAssignmentTransformer(cst.CSTTransformer): - """Transforms global assignments in the original file with those from the new file.""" + """Transforms global assignments in the original file with those from the new file. + + Only transforms simple assignments at module level, NOT inside: + - Functions/classes (scope_depth) + - If/else blocks (if_else_depth) + - For/while loops, with statements, try blocks (compound_depth) + """ def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None: super().__init__() @@ -230,6 +432,7 @@ def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_a self.processed_assignments: set[str] = set() self.scope_depth = 0 self.if_else_depth = 0 + self.compound_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth += 1 @@ -256,8 +459,43 @@ def visit_Else(self, node: cst.Else) -> None: # Else blocks are already counted as part of the if statement pass + def visit_For(self, node: cst.For) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_For(self, original_node: cst.For, updated_node: cst.For) -> cst.For: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_While(self, node: cst.While) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_While(self, original_node: cst.While, updated_node: cst.While) -> cst.While: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_With(self, node: cst.With) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_Try(self, node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_Match(self, node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_Match(self, original_node: cst.Match, updated_node: cst.Match) -> cst.Match: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode: - if self.scope_depth > 0 or self.if_else_depth > 0: + if self.scope_depth > 0 or self.if_else_depth > 0 or self.compound_depth > 0: return updated_node # Check if this is a global assignment we need to replace @@ -287,7 +525,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) - # Find assignments to append + # Find assignments to append (with their names for dependency analysis) assignments_to_append = [ (name, self.new_assignments[name]) for name in self.new_assignment_order @@ -297,58 +535,41 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if not assignments_to_append: return updated_node.with_changes(body=new_statements) - # Collect all class and function names defined in the module - # These are the names that assignments might reference - module_defined_names: set[str] = set() - for stmt in new_statements: - if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): - module_defined_names.add(stmt.name.value) - - # Partition assignments: those that reference module definitions go at the end, - # those that don't can go right after imports - assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] - assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] - - for name, assignment in assignments_to_append: - # Get the value being assigned - if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None: - value_node = assignment.value - else: - # No value to analyze, safe to place after imports - assignments_after_imports.append((name, assignment)) - continue - - # Collect names referenced in the assignment value - referenced_names = collect_referenced_names(value_node) - - # Check if any referenced names are module-level definitions - if referenced_names & module_defined_names: - # This assignment references a class/function, place it after definitions - assignments_after_definitions.append((name, assignment)) - else: - # Safe to place right after imports - assignments_after_imports.append((name, assignment)) - - # Insert assignments that don't depend on module definitions right after imports - if assignments_after_imports: - insert_index = find_insertion_index_after_imports(updated_node) + # Build a map of where functions/classes/assignments are defined in the module + # Maps name -> index AFTER the definition (i.e., the first valid insertion point) + definition_positions: dict[str, int] = {} + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + definition_positions[stmt.name.value] = i + 1 + # Get all names defined by this statement (handles simple assignments, + # compound statements like for/while/with/try, etc.) + for name in get_statement_defined_names(stmt): + definition_positions[name] = i + 1 + + # Find the default insertion index (after imports) + default_insert_index = find_insertion_index_after_imports(updated_node) + + # For each assignment, determine its minimum insertion index + # by finding the max position of all functions/names it depends on + assignments_by_insert_index: dict[int, list[cst.Assign]] = defaultdict(list) + for _name, assignment in assignments_to_append: + # Get names referenced in the assignment's value (RHS) + dependencies = get_names_in_expression(assignment.value) + + # Find the minimum insertion index (after all dependencies are defined) + min_insert_index = default_insert_index + for dep_name in dependencies: + if dep_name in definition_positions: + min_insert_index = max(min_insert_index, definition_positions[dep_name]) + + assignments_by_insert_index[min_insert_index].append(assignment) + + # Insert assignments at their appropriate positions + # Process in reverse order of indices to avoid offset issues when inserting + for insert_index in sorted(assignments_by_insert_index.keys(), reverse=True): + assignments = assignments_by_insert_index[insert_index] assignment_lines = [ - cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for _, assignment in assignments_after_imports - ] - new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) - - # Insert assignments that depend on module definitions after all class/function definitions - if assignments_after_definitions: - # Find the position after the last function or class definition - insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements)) - for i, stmt in enumerate(new_statements): - if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): - insert_index = i + 1 - - assignment_lines = [ - cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for _, assignment in assignments_after_definitions + cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) for assignment in assignments ] new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) @@ -388,13 +609,68 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node.with_changes(body=new_statements) +class FunctionDefCollector(cst.CSTVisitor): + """Collects all top-level function definitions from a module. + + Used to find new helper functions in optimized code that need to be transferred. + """ + + def __init__(self) -> None: + super().__init__() + self.function_defs: dict[str, cst.FunctionDef] = {} + self.function_order: list[str] = [] + self.in_class = False + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 + self.in_class = True + return False # Don't visit inside classes + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.in_class = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if not self.in_class: + name = node.name.value + self.function_defs[name] = node + self.function_order.append(name) + return False # Don't visit inside functions + + +class FunctionNameCollector(cst.CSTVisitor): + """Collects all top-level function and class names from a module.""" + + def __init__(self) -> None: + super().__init__() + self.names: set[str] = set() + self.in_class = False + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + if not self.in_class: + self.names.add(node.name.value) + self.in_class = True + return False + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.in_class = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if not self.in_class: + self.names.add(node.name.value) + return False + + class GlobalStatementCollector(cst.CSTVisitor): - """Visitor that collects all global statements (excluding imports and functions/classes).""" + """Visitor that collects all global statements (excluding imports and functions/classes). + + Collects both simple statements and compound statements (for, while, with, try) + at module level. + """ def __init__(self) -> None: super().__init__() - self.global_statements = [] + self.global_statements: list[cst.BaseStatement] = [] self.in_function_or_class = False + self.compound_depth = 0 def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 # Don't visit inside classes @@ -412,13 +688,69 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002 def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.in_function_or_class = False + def visit_For(self, node: cst.For) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_For(self, original_node: cst.For) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_While(self, node: cst.While) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_While(self, original_node: cst.While) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_With(self, node: cst.With) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_With(self, original_node: cst.With) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Try(self, node: cst.Try) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Match(self, node: cst.Match) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_Match(self, original_node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: - if not self.in_function_or_class: + if not self.in_function_or_class and self.compound_depth == 0: for statement in node.body: - # Skip imports and assignments (both regular and annotated) - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)): + # Skip imports + if isinstance(statement, (cst.Import, cst.ImportFrom)): + continue + # Skip simple name assignments (handled by GlobalAssignmentCollector) + # But include tuple unpacking and chained assignments + if isinstance(statement, cst.Assign): + # Check if it's a simple single-name assignment + if len(statement.targets) == 1 and isinstance(statement.targets[0].target, cst.Name): + continue + # Tuple unpacking, chained assignment, etc. - include it self.global_statements.append(node) break + # Include other statement types (AnnAssign, Expr, etc.) + self.global_statements.append(node) + break class LastImportFinder(cst.CSTVisitor): @@ -511,7 +843,349 @@ def visit_Try(self, node: cst.Try) -> None: self._collect_imports_from_block(node.body) -def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: +class FunctionDefInserter(cst.CSTTransformer): + """Transformer that inserts new function definitions into a module. + + New function definitions are inserted after imports but before any code that depends on them. + """ + + def __init__(self, function_defs: list[cst.FunctionDef], last_import_line: int) -> None: + super().__init__() + self.function_defs = function_defs + self.last_import_line = last_import_line + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.function_defs: + return updated_node + + updated_body = list(updated_node.body) + + # Insert function definitions after imports (or at beginning if no imports) + if self.last_import_line == 0: + # No imports, insert at the beginning + for j, func_def in enumerate(self.function_defs): + updated_body.insert(j, func_def) + else: + # Find insertion point after last import + insertion_index = 0 + for i in range(len(updated_body)): + if i + 1 == self.last_import_line: + insertion_index = i + 1 + break + for j, func_def in enumerate(self.function_defs): + updated_body.insert(insertion_index + j, func_def) + + return updated_node.with_changes(body=updated_body) + + +def get_statement_dependencies(stmt: cst.BaseStatement) -> set[str]: + """Extract all names that a statement depends on (names used on the RHS). + + Note: Type annotations are NOT included as dependencies since they only affect + static type checking and don't cause runtime NameErrors. + """ + deps: set[str] = set() + if isinstance(stmt, cst.SimpleStatementLine): + for body_item in stmt.body: + if isinstance(body_item, cst.Assign): + # Get names used in the value (RHS) + deps.update(get_names_in_expression(body_item.value)) + elif isinstance(body_item, cst.AnnAssign) and body_item.value is not None: + # Only get names from the value, NOT from the annotation + # Type annotations don't create runtime dependencies + deps.update(get_names_in_expression(body_item.value)) + return deps + + +def _extract_names_from_target(target: cst.BaseExpression) -> set[str]: + """Extract all names from an assignment target (handles tuples, names, etc.).""" + names: set[str] = set() + if isinstance(target, cst.Name): + names.add(target.value) + elif isinstance(target, (cst.Tuple, cst.List)): + for element in target.elements: + if isinstance(element, (cst.Element, cst.StarredElement)): + names.update(_extract_names_from_target(element.value)) + return names + + +def _extract_names_from_pattern(pattern: cst.MatchPattern) -> set[str]: + """Extract all names bound by a match pattern (Python 3.10+). + + Match patterns can bind names in several ways: + - MatchAs: `case x:` or `case _ as x:` + - MatchStar: `case [*rest]:` + - MatchMapping: `case {"key": value}:` + - MatchSequence/MatchList/MatchTuple: `case [x, y]:` + - MatchClass: `case Point(x=px, y=py):` + - MatchOr: `case x | y:` (names must be same in all alternatives) + """ + names: set[str] = set() + + if isinstance(pattern, cst.MatchAs): + # `case x:` or `case _ as x:` - the name is bound + if pattern.name is not None: + names.add(pattern.name.value) + # Also check nested pattern (e.g., `case [a, b] as both:`) + if pattern.pattern is not None: + names.update(_extract_names_from_pattern(pattern.pattern)) + + elif isinstance(pattern, cst.MatchStar): + # `case [first, *rest]:` - rest is bound + if pattern.name is not None: + names.add(pattern.name.value) + + elif isinstance(pattern, (cst.MatchSequence, cst.MatchList, cst.MatchTuple)): + # `case [x, y]:` or `case (x, y):` + for element in pattern.patterns: + if isinstance(element, cst.MatchSequenceElement): + names.update(_extract_names_from_pattern(element.value)) + elif isinstance(element, cst.MatchStar) and element.name is not None: + names.add(element.name.value) + + elif isinstance(pattern, cst.MatchMapping): + # `case {"key": value}:` + for element in pattern.elements: + if isinstance(element, cst.MatchMappingElement): + names.update(_extract_names_from_pattern(element.pattern)) + # Also handle `case {**rest}:` + if pattern.rest is not None: + names.add(pattern.rest.value) + + elif isinstance(pattern, cst.MatchClass): + # `case Point(x, y):` or `case Point(x=px, y=py):` + for arg in pattern.patterns: + if isinstance(arg, cst.MatchSequenceElement): + names.update(_extract_names_from_pattern(arg.value)) + for kwarg in pattern.keywords: + if isinstance(kwarg, cst.MatchKeywordElement): + names.update(_extract_names_from_pattern(kwarg.pattern)) + + elif isinstance(pattern, cst.MatchOr): + # `case x | y:` - all alternatives must bind same names + for element in pattern.patterns: + if isinstance(element, cst.MatchOrElement): + names.update(_extract_names_from_pattern(element.pattern)) + + # MatchValue and MatchSingleton don't bind names + + return names + + +def _collect_defined_names_from_body(body: cst.BaseSuite) -> set[str]: + """Recursively collect all names defined by assignments inside a code block. + + This is used to find names defined inside compound statements (for, while, with, try). + """ + names: set[str] = set() + if isinstance(body, cst.IndentedBlock): + for stmt in body.body: + names.update(get_statement_defined_names(stmt)) + return names + + +def get_statement_defined_names(stmt: cst.BaseStatement) -> set[str]: + """Get all names defined by a statement (LHS of assignment). + + For compound statements (for, while, with, try, if), this also collects + names defined inside their bodies. + """ + names: set[str] = set() + if isinstance(stmt, cst.SimpleStatementLine): + for body_item in stmt.body: + if isinstance(body_item, cst.Assign): + # Handle chained assignments: a = b = c = 5 + for target_node in body_item.targets: + names.update(_extract_names_from_target(target_node.target)) + elif isinstance(body_item, cst.AnnAssign): + names.update(_extract_names_from_target(body_item.target)) + elif isinstance(stmt, cst.For): + # For loop target variable + names.update(_extract_names_from_target(stmt.target)) + # Names defined inside the for loop body + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt, cst.While): + # Names defined inside the while loop body + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt, cst.With): + # With statement can bind names via 'as' clause + for item in stmt.items: + if item.asname: + names.update(_extract_names_from_target(item.asname.name)) + # Names defined inside the with body + names.update(_collect_defined_names_from_body(stmt.body)) + elif isinstance(stmt, cst.Try): + # Names defined in try/except/else/finally bodies + names.update(_collect_defined_names_from_body(stmt.body)) + for handler in stmt.handlers: + if handler.name: + names.add(handler.name.value) + names.update(_collect_defined_names_from_body(handler.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + if stmt.finalbody: + names.update(_collect_defined_names_from_body(stmt.finalbody.body)) + elif isinstance(stmt, cst.If): + # Names defined inside if/elif/else bodies + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + if isinstance(stmt.orelse, cst.Else): + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt.orelse, cst.If): + # elif branch + names.update(get_statement_defined_names(stmt.orelse)) + elif isinstance(stmt, cst.Match): + # Match statement (Python 3.10+) can bind names through patterns + for case in stmt.cases: + # Extract names bound by the pattern + names.update(_extract_names_from_pattern(case.pattern)) + # Names defined inside the case body + names.update(_collect_defined_names_from_body(case.body)) + return names + + +def find_last_definition_index(name: str, body: list[cst.BaseStatement]) -> int: + """Find the index of the last statement that defines a given name.""" + last_idx = -1 + for i, stmt in enumerate(body): + defined_names = get_statement_defined_names(stmt) + if name in defined_names: + last_idx = i + return last_idx + + +def _sort_statements_by_dependencies(statements: list[cst.BaseStatement]) -> list[cst.BaseStatement]: + """Sort statements so that definitions come before their dependents. + + Uses a stable topological sort - statements without dependencies on each other + maintain their original relative order. + """ + if len(statements) <= 1: + return statements + + # Build a map of defined names to statement index + name_to_idx: dict[str, int] = {} + for i, stmt in enumerate(statements): + for name in get_statement_defined_names(stmt): + name_to_idx[name] = i + + # Build dependency graph: idx -> set of indices it depends on + deps_graph: dict[int, set[int]] = {i: set() for i in range(len(statements))} + for i, stmt in enumerate(statements): + for dep_name in get_statement_dependencies(stmt): + if dep_name in name_to_idx: + dep_idx = name_to_idx[dep_name] + if dep_idx != i: # Don't add self-dependency + deps_graph[i].add(dep_idx) + + # Kahn's algorithm for topological sort with stable ordering + in_degree = {i: len(deps) for i, deps in deps_graph.items()} + # Start with nodes that have no dependencies, in original order + queue = [i for i in range(len(statements)) if in_degree[i] == 0] + result: list[cst.BaseStatement] = [] + + while queue: + # Take from front (maintains original order for independent items) + idx = queue.pop(0) + result.append(statements[idx]) + + # Find nodes that depend on this one and reduce their in-degree + for other_idx in range(len(statements)): + if idx in deps_graph[other_idx]: + in_degree[other_idx] -= 1 + if in_degree[other_idx] == 0: + queue.append(other_idx) + + # If we couldn't sort all (cycle), return original order + if len(result) != len(statements): + # Find the statements involved in the cycle for a more helpful warning + unsorted_indices = [i for i in range(len(statements)) if in_degree[i] > 0] + cycle_names = [] + for idx in unsorted_indices: + names = get_statement_defined_names(statements[idx]) + if names: + cycle_names.extend(names) + logger.warning( + f"Circular dependency detected among statements defining: {', '.join(cycle_names) or 'unknown'}. " + "Using original statement order. This may cause NameError at runtime." + ) + return statements + + return result + + +class ImportInserter(cst.CSTTransformer): + """Transformer that inserts global statements into a module. + + - Simple statements (assignments, etc.) are inserted after the globals they depend on + - If no dependencies, they are inserted after imports + - Compound statements (for, while, with, try) are inserted at the END of the module + because they may call functions that are defined later in the file + """ + + def __init__(self, global_statements: list[cst.BaseStatement], last_import_line: int) -> None: + super().__init__() + self.global_statements = global_statements + self.last_import_line = last_import_line + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.global_statements: + return updated_node + + # Separate simple statements from compound statements + simple_statements = [] + compound_statements = [] + for stmt in self.global_statements: + if isinstance(stmt, cst.SimpleStatementLine): + simple_statements.append(stmt) + else: + # For, While, With, Try, If, etc. are compound statements + compound_statements.append(stmt) + + updated_body = list(updated_node.body) + + # For simple statements, insert after their dependencies or after imports + if simple_statements: + # Sort statements by their interdependencies first + simple_statements = _sort_statements_by_dependencies(simple_statements) + + # Find the base insertion point (after imports) + base_insertion_index = 0 + if self.last_import_line > 0: + for i in range(len(updated_body)): + if i + 1 == self.last_import_line: + base_insertion_index = i + 1 + break + + # Insert each statement at the correct position based on its dependencies + # Track how many statements have been inserted to maintain relative order + for inserted_count, stmt in enumerate(simple_statements): + deps = get_statement_dependencies(stmt) + + # Find the position after the last dependency definition + # Account for already-inserted statements to maintain order + insertion_index = base_insertion_index + inserted_count + for dep_name in deps: + dep_idx = find_last_definition_index(dep_name, updated_body) + if dep_idx >= 0: + # Insert after this dependency + insertion_index = max(insertion_index, dep_idx + 1) + + updated_body.insert(insertion_index, stmt) + + # Append compound statements at the END of the module + # This ensures they run after all function definitions + if compound_statements: + updated_body.extend(compound_statements) + + return updated_node.with_changes(body=updated_body) + + +def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.BaseStatement]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) collector = GlobalStatementCollector() @@ -554,6 +1228,20 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: src_module, new_added_global_statements = extract_global_statements(src_module_code) dst_module, existing_global_statements = extract_global_statements(dst_module_code) + # Collect function definitions from source and destination + src_func_collector = FunctionDefCollector() + src_module.visit(src_func_collector) + + dst_name_collector = FunctionNameCollector() + dst_module.visit(dst_name_collector) + + # Find new function definitions that don't exist in destination + new_function_defs: list[cst.FunctionDef] = [ + src_func_collector.function_defs[func_name] + for func_name in src_func_collector.function_order + if func_name not in dst_name_collector.names + ] + unique_global_statements = [] for stmt in new_added_global_statements: if any( @@ -565,6 +1253,14 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: # Reuse already-parsed dst_module original_module = dst_module + # Insert new function definitions if any + if new_function_defs: + last_import_line = find_last_import_line(dst_module_code) + transformer = FunctionDefInserter(new_function_defs, last_import_line) + modified_module = original_module.visit(transformer) + dst_module_code = modified_module.code + original_module = cst.parse_module(dst_module_code) + # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file new_assignment_collector = GlobalAssignmentCollector() diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index fe0ff095e..3a1021d54 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -23,13 +23,25 @@ class PrComment: benchmark_details: Optional[list[BenchmarkDetail]] = None original_async_throughput: Optional[int] = None best_async_throughput: Optional[int] = None + # Optional pre-computed values (used by create-pr CLI command) + precomputed_test_report: Optional[dict[str, dict[str, int]]] = None + precomputed_loop_count: Optional[int] = None def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]: - report_table = { - test_type.to_name(): result - for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() - if test_type.to_name() - } + # Use precomputed values if available, otherwise compute from TestResults + if self.precomputed_test_report is not None: + report_table = self.precomputed_test_report + else: + report_table = { + test_type.to_name(): result + for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() + if test_type.to_name() + } + loop_count = ( + self.precomputed_loop_count + if self.precomputed_loop_count is not None + else self.winning_benchmarking_test_results.number_of_loops() + ) result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = { "optimization_explanation": self.optimization_explanation, @@ -39,7 +51,7 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B "file_path": self.relative_file_path, "speedup_x": self.speedup_x, "speedup_pct": self.speedup_pct, - "loop_count": self.winning_benchmarking_test_results.number_of_loops(), + "loop_count": loop_count, "report_table": report_table, "benchmark_details": self.benchmark_details if self.benchmark_details else None, } diff --git a/codeflash/main.py b/codeflash/main.py index 31afd0305..67880a7fc 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -32,7 +32,7 @@ def main() -> None: disable_telemetry = pyproject_config.get("disable_telemetry", False) init_sentry(not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not disable_telemetry) - args.func() + args.func(args) elif args.verify_setup: args = process_pyproject_config(args) init_sentry(not args.disable_telemetry, exclude_errors=True) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index dc5b82923..c269a5d93 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -529,6 +529,7 @@ class OptimizedCandidateSource(str, Enum): REPAIR = "REPAIR" ADAPTIVE = "ADAPTIVE" JIT_REWRITE = "JIT_REWRITE" + AUGMENTED = "AUGMENTED" @dataclass(frozen=True) @@ -955,3 +956,51 @@ def __eq__(self, other: object) -> bool: return False sys.setrecursionlimit(original_recursion_limit) return True + + +class Phase1CandidateResult(BaseModel): + optimization_id: str + source_code: str + explanation: str + speedup_ratio: Optional[float] = None + runtime_ns: Optional[int] = None + is_correct: bool + line_profiler_results: Optional[str] = None + test_failures: Optional[list[str]] = None + test_diffs: Optional[list[dict]] = None + + +class Phase1FunctionResult(BaseModel): + function_name: str + trace_id: str + original_source_code: str + dependency_code: Optional[str] = None + original_runtime_ns: Optional[int] = None + original_line_profiler_results: Optional[str] = None + candidates: list[Phase1CandidateResult] + best_candidate_id: Optional[str] = None + best_speedup_ratio: Optional[float] = None + # PR creation data - captured after best candidate is selected + file_path: Optional[str] = None + existing_tests_source: Optional[str] = None + replay_tests_source: Optional[str] = None + concolic_tests_source: Optional[str] = None + best_candidate_explanation: Optional[str] = None + best_runtime_ns: Optional[int] = None + # Test results summary for PR creation + test_report: Optional[dict[str, dict[str, int]]] = None # test_type_name -> {passed: int, failed: int} + loop_count: Optional[int] = None + + +class Phase1Output(BaseModel): + codeflash_version: str + timestamp: str + python_version: str + functions: list[Phase1FunctionResult] + total_functions: int + successful_optimizations: int + + +class AugmentedPrompts(BaseModel): + system_prompt: str + user_prompt: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 697061b9c..0bd211ad8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -4,16 +4,19 @@ import concurrent.futures import logging import os +import platform import queue import random import subprocess import uuid from collections import defaultdict +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Callable import libcst as cst import sentry_sdk +import yaml from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -81,6 +84,7 @@ AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest, AIServiceCodeRepairRequest, + AugmentedPrompts, BestOptimization, CandidateEvaluationContext, CodeOptimizationContext, @@ -92,6 +96,9 @@ OptimizedCandidateResult, OptimizedCandidateSource, OriginalCodeBaseline, + Phase1CandidateResult, + Phase1FunctionResult, + Phase1Output, TestFile, TestFiles, TestingMode, @@ -475,6 +482,117 @@ def __init__( self.adaptive_optimization_counter = 0 # track how many adaptive optimizations we did for each function self.is_numerical_code: bool | None = None + self.augmented_mode = getattr(args, "augmented", False) if args else False + self.augmented_prompt_file = getattr(args, "augmented_prompt_file", None) if args else None + self.augmented_output = getattr(args, "augmented_output", "codeflash_phase1_results.json") if args else None + self.augmented_prompts: AugmentedPrompts | None = None + self.phase1_candidate_results: list[Phase1CandidateResult] = [] + # PR data captured during process_review for Phase1 output + self.phase1_existing_tests: str | None = None + self.phase1_replay_tests: str | None = None + self.phase1_concolic_tests: str | None = None + self.phase1_best_explanation: str | None = None + self.phase1_best_runtime: int | None = None + self.phase1_test_report: dict[str, dict[str, int]] | None = None + self.phase1_loop_count: int | None = None + + def load_augmented_prompts(self) -> AugmentedPrompts | None: + if not self.augmented_prompt_file: + return None + prompt_path = Path(self.augmented_prompt_file) + if not prompt_path.exists(): + logger.error(f"Augmented prompt file not found: {prompt_path}") + return None + with prompt_path.open(encoding="utf-8") as f: + data = yaml.safe_load(f) + if not data or "system_prompt" not in data or "user_prompt" not in data: + logger.error("Augmented prompt file must contain 'system_prompt' and 'user_prompt' keys") + return None + return AugmentedPrompts(system_prompt=data["system_prompt"], user_prompt=data["user_prompt"]) + + def collect_phase1_candidate_result( + self, + candidate: OptimizedCandidate, + eval_ctx: CandidateEvaluationContext, + test_failures: list[str] | None = None, + test_diffs: list[dict] | None = None, + ) -> Phase1CandidateResult: + speedup = eval_ctx.get_speedup_ratio(candidate.optimization_id) + runtime = eval_ctx.get_optimized_runtime(candidate.optimization_id) + is_correct = eval_ctx.is_correct.get(candidate.optimization_id, False) + line_profiler = eval_ctx.optimized_line_profiler_results.get(candidate.optimization_id) + return Phase1CandidateResult( + optimization_id=candidate.optimization_id, + source_code=candidate.source_code.markdown, + explanation=candidate.explanation, + speedup_ratio=speedup, + runtime_ns=int(runtime) if runtime else None, + is_correct=is_correct, + line_profiler_results=line_profiler, + test_failures=test_failures, + test_diffs=test_diffs, + ) + + def get_phase1_function_result( + self, + code_context: CodeOptimizationContext, + original_runtime_ns: int | None, + original_line_profiler_results: str | None, + best_candidate_id: str | None, + best_speedup_ratio: float | None, + ) -> Phase1FunctionResult: + return Phase1FunctionResult( + function_name=self.function_to_optimize.function_name, + trace_id=self.function_trace_id, + original_source_code=code_context.read_writable_code.markdown, + dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None, + original_runtime_ns=original_runtime_ns, + original_line_profiler_results=original_line_profiler_results, + candidates=self.phase1_candidate_results, + best_candidate_id=best_candidate_id, + best_speedup_ratio=best_speedup_ratio, + # PR creation data + file_path=self.function_to_optimize.file_path.as_posix(), + existing_tests_source=self.phase1_existing_tests, + replay_tests_source=self.phase1_replay_tests, + concolic_tests_source=self.phase1_concolic_tests, + best_candidate_explanation=self.phase1_best_explanation, + best_runtime_ns=self.phase1_best_runtime, + test_report=self.phase1_test_report, + loop_count=self.phase1_loop_count, + ) + + def write_phase1_output(self, function_result: Phase1FunctionResult) -> None: + from codeflash.version import __version__ as codeflash_version + + output_path = Path(self.augmented_output) if self.augmented_output else Path("codeflash_phase1_results.json") + output = Phase1Output( + codeflash_version=codeflash_version, + timestamp=datetime.now(tz=timezone.utc).isoformat(), + python_version=platform.python_version(), + functions=[function_result], + total_functions=1, + successful_optimizations=1 if function_result.best_candidate_id else 0, + ) + with output_path.open("w", encoding="utf-8") as f: + f.write(output.model_dump_json(indent=2)) + logger.info(f"Phase 1 results written to {output_path}") + + def get_augmented_candidates(self, code_context: CodeOptimizationContext) -> list[OptimizedCandidate]: + if not self.augmented_prompts: + logger.error("Augmented prompts not loaded") + return [] + candidates = self.aiservice_client.augmented_optimize( + source_code=code_context.read_writable_code.markdown, + system_prompt=self.augmented_prompts.system_prompt, + user_prompt=self.augmented_prompts.user_prompt, + trace_id=self.function_trace_id, + dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None, + n_candidates=3, + ) + logger.info(f"Received {len(candidates)} augmented optimization candidates") + return candidates + def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}") @@ -640,6 +758,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, is_numerical_code=self.is_numerical_code, + code_context=code_context, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -707,6 +826,23 @@ def optimize_function(self) -> Result[BestOptimization, str]: if self.args.override_fixtures: restore_conftest(original_conftest_content) + + if self.augmented_mode: + function_result = self.get_phase1_function_result( + code_context=code_context, + original_runtime_ns=original_code_baseline.runtime if original_code_baseline else None, + original_line_profiler_results=original_code_baseline.line_profile_results.get("str_out") + if original_code_baseline and original_code_baseline.line_profile_results + else None, + best_candidate_id=best_optimization.candidate.optimization_id if best_optimization else None, + best_speedup_ratio=performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime + ) + if best_optimization + else None, + ) + self.write_phase1_output(function_result) + if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) @@ -1221,6 +1357,11 @@ def determine_best_candidate( exp_type=exp_type, ) + if self.augmented_mode: + for candidate in candidates: + phase1_result = self.collect_phase1_candidate_result(candidate=candidate, eval_ctx=eval_ctx) + self.phase1_candidate_results.append(phase1_result) + return best_optimization def call_adaptive_optimize( @@ -1627,8 +1768,31 @@ def generate_optimizations( read_only_context_code: str, run_experiment: bool = False, # noqa: FBT001, FBT002 is_numerical_code: bool | None = None, # noqa: FBT001 + code_context: CodeOptimizationContext | None = None, ) -> Result[tuple[OptimizationSet, str], str]: """Generate optimization candidates for the function. Backend handles multi-model diversity.""" + if self.augmented_mode and self.augmented_prompt_file: + self.augmented_prompts = self.load_augmented_prompts() + if not self.augmented_prompts: + return Failure("Failed to load augmented prompts from file") + if code_context is None: + return Failure("Code context required for augmented mode") + augmented_candidates = self.get_augmented_candidates(code_context) + if not augmented_candidates: + return Failure( + f"/!\\ NO AUGMENTED OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}" + ) + future_references = self.executor.submit( + get_opt_review_metrics, + self.function_to_optimize_source_code, + self.function_to_optimize.file_path, + self.function_to_optimize.qualified_name, + self.project_root, + self.test_cfg.tests_root, + ) + function_references = future_references.result() + return Success((OptimizationSet(control=augmented_candidates, experiment=None), function_references)) + n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort) future_optimization_candidates = self.executor.submit( self.aiservice_client.optimize_python_code, @@ -1987,6 +2151,21 @@ def process_review( best_optimization.explanation_v2 = new_explanation.explanation_message() + # Capture PR data for Phase1 output in augmented mode + if self.augmented_mode: + self.phase1_existing_tests = existing_tests + self.phase1_replay_tests = replay_tests + self.phase1_concolic_tests = concolic_tests + self.phase1_best_explanation = new_explanation.explanation_message() + self.phase1_best_runtime = best_optimization.runtime + # Capture test results for PR creation + self.phase1_test_report = { + tt.to_name(): counts + for tt, counts in new_explanation.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() + if tt.to_name() # Skip empty names + } + self.phase1_loop_count = new_explanation.winning_benchmarking_test_results.number_of_loops() + data = { "original_code": original_code_combined, "new_code": new_code_combined, diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index f888f710a..da3ca1a72 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -186,6 +186,8 @@ def check_create_pr( root_dir: Path, git_remote: Optional[str] = None, optimization_review: str = "", + precomputed_test_report: Optional[dict[str, dict[str, int]]] = None, + precomputed_loop_count: Optional[int] = None, ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() git_repo = git.Repo(search_parent_directories=True) @@ -222,6 +224,8 @@ def check_create_pr( benchmark_details=explanation.benchmark_details, original_async_throughput=explanation.original_async_throughput, best_async_throughput=explanation.best_async_throughput, + precomputed_test_report=precomputed_test_report, + precomputed_loop_count=precomputed_loop_count, ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, @@ -274,6 +278,8 @@ def check_create_pr( benchmark_details=explanation.benchmark_details, original_async_throughput=explanation.original_async_throughput, best_async_throughput=explanation.best_async_throughput, + precomputed_test_report=precomputed_test_report, + precomputed_loop_count=precomputed_loop_count, ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, diff --git a/pyproject.toml b/pyproject.toml index 1714532d0..9c6d8f312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "codeflash-benchmark", "filelock", "pytest-asyncio>=1.2.0", + "pyyaml>=6.0.3", ] [project.urls] @@ -76,6 +77,7 @@ dev = [ "types-unidiff>=0.7.0.20240505,<0.8", "uv>=0.6.2", "pre-commit>=4.2.0,<5", + "ty>=0.0.12", ] tests = [ "black>=25.9.0", diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index da83146a8..65967a9b9 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -3742,3 +3742,1736 @@ def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]: project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected + + +def test_global_assignments_with_function_dependencies() -> None: + """Test that global assignments that depend on functions are inserted after those functions. + + This tests the fix for a bug where LLM-generated optimizations that use module-level + code like `_TABLE = unicode_to_char(...)` would fail because the assignment was inserted + before the function definition. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original file: standardize_quotes first, unicode_to_char second + original_code = '''def standardize_quotes(text: str) -> str: + """Standardize quotes in text.""" + return text + +def unicode_to_char(unicode_val: str) -> str: + """Convert unicode value to char.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code: defines unicode_to_char first, then module-level code that uses it + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Convert unicode value to char.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_CODES = ("U+0022", "U+201C") +_TRANSLATION_TABLE = {ord(unicode_to_char(c)): ord('"') for c in _CODES} + +def standardize_quotes(text: str) -> str: + """Standardize quotes in text.""" + return text.translate(_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The assignment that depends on unicode_to_char should be inserted AFTER unicode_to_char + # not after imports (which would cause a NameError) + + # Parse the result and verify the order + import libcst as cst + + module = cst.parse_module(result) + + # Find positions of key elements + unicode_to_char_pos = None + translation_table_pos = None + + for i, stmt in enumerate(module.body): + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "unicode_to_char": + unicode_to_char_pos = i + elif isinstance(stmt, cst.SimpleStatementLine): + for child in stmt.body: + if isinstance(child, cst.Assign): + for target in child.targets: + if isinstance(target.target, cst.Name) and target.target.value == "_TRANSLATION_TABLE": + translation_table_pos = i + + # Verify that _TRANSLATION_TABLE comes AFTER unicode_to_char + assert unicode_to_char_pos is not None, "unicode_to_char function not found in result" + assert translation_table_pos is not None, "_TRANSLATION_TABLE assignment not found in result" + assert translation_table_pos > unicode_to_char_pos, ( + f"_TRANSLATION_TABLE (pos {translation_table_pos}) should be after " + f"unicode_to_char (pos {unicode_to_char_pos}) because it depends on it" + ) + + +def test_global_assignments_inside_for_loops_not_extracted() -> None: + """Test that assignments inside for-loops are NOT extracted as standalone globals. + + This tests the fix for a bug where LLM-generated optimizations that build + translation tables using for-loops would have the loop body assignments + incorrectly extracted, causing NameError for loop variables. + """ + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + # Code with assignments inside a for-loop (common optimization pattern) + # Note: Using regular assignment (not annotated) since GlobalAssignmentCollector only handles Assign + code_with_for_loop = ''' +double_quotes = {"a": "U+0022", "b": "U+201C"} + +_QUOTE_TRANSLATION = {} +for unicode_val in double_quotes.values(): + ch = unicode_to_char(unicode_val) + _QUOTE_TRANSLATION[ord(ch)] = _double_quote_standard + +def standardize_quotes(text: str) -> str: + return text.translate(_QUOTE_TRANSLATION) +''' + + module = cst.parse_module(code_with_for_loop) + collector = GlobalAssignmentCollector() + module.visit(collector) + + # Only the top-level assignments should be collected, NOT the one inside the for-loop + # _QUOTE_TRANSLATION = {} is at module level (should be collected) + # _QUOTE_TRANSLATION[ord(ch)] = ... is inside for-loop (should NOT be collected) + # double_quotes = {...} is at module level (should be collected) + assert "double_quotes" in collector.assignments, "double_quotes should be collected" + assert "_QUOTE_TRANSLATION" in collector.assignments, "_QUOTE_TRANSLATION init should be collected" + + # The assignment inside the for-loop uses subscript, not Name, so it wouldn't be + # collected anyway. But let's verify the collector doesn't crash and works correctly. + # More importantly, verify that simple name assignments inside loops are NOT collected. + + code_with_simple_assignment_in_loop = ''' +result = {} +for item in items: + key = process(item) + result[key] = item +''' + module2 = cst.parse_module(code_with_simple_assignment_in_loop) + collector2 = GlobalAssignmentCollector() + module2.visit(collector2) + + assert "result" in collector2.assignments, "result should be collected (top-level)" + assert "key" not in collector2.assignments, "key should NOT be collected (inside for-loop)" + + +def test_global_assignments_inside_while_loops_not_extracted() -> None: + """Test that assignments inside while-loops are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_while_loop = ''' +counter = 0 +while counter < 10: + value = compute(counter) + counter += 1 +''' + module = cst.parse_module(code_with_while_loop) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "counter" in collector.assignments, "counter should be collected (top-level)" + assert "value" not in collector.assignments, "value should NOT be collected (inside while-loop)" + + +def test_global_assignments_inside_with_blocks_not_extracted() -> None: + """Test that assignments inside with-blocks are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_with_block = ''' +config = {} +with open("file.txt") as f: + content = f.read() + data = parse(content) +''' + module = cst.parse_module(code_with_with_block) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "config" in collector.assignments, "config should be collected (top-level)" + assert "content" not in collector.assignments, "content should NOT be collected (inside with-block)" + assert "data" not in collector.assignments, "data should NOT be collected (inside with-block)" + + +def test_global_assignments_inside_try_blocks_not_extracted() -> None: + """Test that assignments inside try-blocks are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_try_block = ''' +default = "fallback" +try: + result = risky_operation() + processed = transform(result) +except Exception: + pass +''' + module = cst.parse_module(code_with_try_block) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "default" in collector.assignments, "default should be collected (top-level)" + assert "result" not in collector.assignments, "result should NOT be collected (inside try-block)" + assert "processed" not in collector.assignments, "processed should NOT be collected (inside try-block)" + + +def test_global_assignment_transformer_ignores_loop_assignments() -> None: + """Test that GlobalAssignmentTransformer doesn't replace assignments inside loops. + + The transformer should: + 1. NOT replace assignments inside for/while/with/try blocks + 2. Still add new top-level assignments that weren't in the original + """ + from codeflash.code_utils.code_extractor import GlobalAssignmentTransformer + + import libcst as cst + + # Original code has 'key' inside a for-loop and 'result' at top level + original_code = ''' +result = {} +for item in items: + key = process(item) + result[key] = item +''' + # New assignments - 'result' should replace top-level, 'key' should NOT replace loop var + new_assignments = { + "result": cst.parse_statement("result = {'new': 'dict'}").body[0], + "key": cst.parse_statement("key = 'new_value'").body[0], + } + + module = cst.parse_module(original_code) + transformer = GlobalAssignmentTransformer(new_assignments, ["result", "key"]) + result_module = module.visit(transformer) + + result_code = result_module.code + + # The 'key' inside the for-loop should NOT be replaced + assert "key = process(item)" in result_code, "Assignment inside for-loop should not be transformed" + + # The top-level 'result' SHOULD be replaced + assert "result = {'new': 'dict'}" in result_code, "Top-level assignment should be replaced" + assert "result = {}" not in result_code, "Original top-level assignment should be gone" + + +def test_add_global_assignments_with_loop_variables() -> None: + """Test that add_global_assignments doesn't extract assignments that reference loop variables. + + This is the specific bug case from optimization runs where code like: + for uval in double_quotes.values(): + ch = unicode_to_char(uval) + _translation_map[ord(ch)] = _double_quote_ord + + Would have 'ch = unicode_to_char(uval)' extracted and inserted at module level, + causing 'NameError: name 'uval' is not defined'. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original simple function + original_code = '''def standardize_quotes(text: str) -> str: + return text + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with for-loop that builds translation table at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {"a": "U+0022", "b": "U+201C"} +_translation_map = {} + +for uval in double_quotes.values(): + ch = unicode_to_char(uval) + _translation_map[ord(ch)] = ord('"') + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + # If 'ch = unicode_to_char(uval)' was incorrectly extracted, it would cause + # NameError because 'uval' wouldn't be defined outside the for-loop + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (loop var extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop structure is preserved (ch assignment inside loop) + assert "for uval in" in result or "for unicode_val in" in result or "ch" not in result, ( + "If ch is in result, the for-loop should also be present" + ) + + +def test_add_global_assignments_with_unicode_val_loop_variable() -> None: + """Test that add_global_assignments correctly transfers for-loops using 'unicode_val' as loop variable. + + This is the specific bug case: NameError: name 'unicode_val' is not defined + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def standardize_quotes(text: str) -> str: + return text +''' + + # Optimized code with for-loop using unicode_val as the loop variable + optimized_code = '''single_quotes = {"U+0027": "'", "U+2018": "'", "U+2019": "'"} +_translation_map = {} + +for unicode_val in single_quotes: + ch = chr(int(unicode_val.replace("U+", ""), 16)) + _translation_map[ord(ch)] = ord("'") + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (unicode_val extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for unicode_val in" in result, "For-loop with unicode_val should be transferred" + + +def test_add_global_assignments_with_ch_loop_variable() -> None: + """Test that add_global_assignments correctly transfers for-loops using 'ch' as loop variable. + + This is the specific bug case: NameError: name 'ch' is not defined + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def normalize_text(text: str) -> str: + return text +''' + + # Optimized code with for-loop using ch as the loop variable + optimized_code = '''replacements = {"a": "A", "b": "B"} +_char_map = {} + +for ch in replacements: + _char_map[ord(ch)] = ord(replacements[ch]) + +def normalize_text(text: str) -> str: + return text.translate(_char_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (ch extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for ch in" in result, "For-loop with ch should be transferred" + + +def test_add_global_assignments_with_helper_function_call() -> None: + """Test that add_global_assignments transfers assignments that call helper functions. + + Note: add_global_assignments only handles assignments, not function definitions. + Function definitions are transferred via replace_functions_in_file separately. + + This test verifies that assignments calling helper functions are correctly transferred. + The actual NameError: name '_build_quote_translation_table' is not defined would occur + if the function wasn't transferred in the full replacement flow. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code that already has the helper function defined + original_code = '''def _build_quote_translation_table(): + table = {} + for i in range(128): + table[i] = i + return table + +def standardize_quotes(text: str) -> str: + return text +''' + + # Optimized code with an assignment that calls the helper function + optimized_code = '''def _build_quote_translation_table(): + table = {} + for i in range(128): + table[i] = i + return table + +_QUOTE_TABLE = _build_quote_translation_table() + +def standardize_quotes(text: str) -> str: + return text.translate(_QUOTE_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError: {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the assignment is present and placed AFTER the function definition + assert "_QUOTE_TABLE = _build_quote_translation_table()" in result, "Assignment should be transferred" + # Verify the function is still present (it was already in original) + assert "def _build_quote_translation_table" in result, "Helper function should be preserved" + + # Verify correct ordering: function must come before the assignment that uses it + func_pos = result.index("def _build_quote_translation_table") + assign_pos = result.index("_QUOTE_TABLE = _build_quote_translation_table()") + assert func_pos < assign_pos, "Function definition must come before assignment that calls it" + + +def test_add_global_assignments_forloop_calls_function_defined_later() -> None: + """Test that for-loops calling functions are placed AFTER those function definitions. + + This is the specific bug case: NameError: name 'unicode_to_char' is not defined + + When the original file has a function defined later (e.g., after standardize_quotes), + and the optimized code adds a for-loop that calls that function, the for-loop must + be placed AFTER the function definition, not at the top of the file. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code where unicode_to_char is defined AFTER the main function + # This mirrors the real-world case where the helper is at the bottom + original_code = '''def standardize_quotes(text: str) -> str: + return text + + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with a for-loop that calls unicode_to_char at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_UNICODE_VALUES = ["U+0022", "U+201C", "U+201D"] +_TRANSLATION_TABLE = {} + +for code in _DOUBLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(code) + _TRANSLATION_TABLE[ord(ch)] = ord('"') + +def standardize_quotes(text: str) -> str: + return text.translate(_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + # Also try to actually execute it to catch runtime NameErrors + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError (for-loop placed before function): {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for code in _DOUBLE_QUOTE_UNICODE_VALUES" in result, "For-loop should be transferred" + + # Verify correct ordering: function must come before the for-loop that calls it + func_pos = result.index("def unicode_to_char") + forloop_pos = result.index("for code in _DOUBLE_QUOTE_UNICODE_VALUES") + assert func_pos < forloop_pos, ( + f"Function definition must come before for-loop that calls it.\n" + f"Function at position {func_pos}, for-loop at position {forloop_pos}\n" + f"Generated code:\n{result}" + ) + + +def test_add_global_assignments_forloop_uses_computed_variable() -> None: + """Test that for-loops are placed after variables they depend on. + + This is the specific bug case: NameError: name '_double_chars' is not defined + + When optimized code computes a variable and then uses it in a for-loop, + the assignment must come before the for-loop. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def process_text(text: str) -> str: + return text +''' + + # Optimized code where _double_chars is computed and then used in a for-loop + optimized_code = '''_UNICODE_VALUES = ["U+0022", "U+201C"] +_double_chars = tuple(chr(int(u.replace("U+", ""), 16)) for u in _UNICODE_VALUES) + +_TRANSLATION = {} +for ch in _double_chars: + _TRANSLATION[ord(ch)] = ord('"') + +def process_text(text: str) -> str: + return text.translate(_TRANSLATION) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError (variable not defined before for-loop): {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for ch in _double_chars" in result, "For-loop should be transferred" + + # Verify correct ordering: _double_chars assignment must come before the for-loop + assign_pos = result.index("_double_chars = ") + forloop_pos = result.index("for ch in _double_chars") + assert assign_pos < forloop_pos, ( + f"Variable assignment must come before for-loop that uses it.\n" + f"Assignment at position {assign_pos}, for-loop at position {forloop_pos}\n" + f"Generated code:\n{result}" + ) + + +def test_add_global_assignments_multiple_forloops_with_dependencies() -> None: + """Test that multiple for-loops with function dependencies are ordered correctly. + + This tests the real-world case from standardize_quotes optimization where: + 1. unicode_to_char function is used + 2. Multiple for-loops build translation tables + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code with function defined after main function + original_code = '''def standardize_quotes(text: str) -> str: + return text + + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with multiple for-loops that depend on unicode_to_char + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {"U+0022": '"', "U+201C": '"'} +single_quotes = {"U+0027": "'", "U+2018": "'"} + +_translation_table = {} + +for unicode_val in double_quotes: + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = ord('"') + +for unicode_val in single_quotes: + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = ord("'") + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_table) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError: {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify both for-loops are present + assert "for unicode_val in double_quotes" in result, "First for-loop should be transferred" + assert "for unicode_val in single_quotes" in result, "Second for-loop should be transferred" + + # Verify correct ordering: function must come before all for-loops + func_pos = result.index("def unicode_to_char") + first_forloop_pos = result.index("for unicode_val in double_quotes") + second_forloop_pos = result.index("for unicode_val in single_quotes") + + assert func_pos < first_forloop_pos, "Function must come before first for-loop" + assert func_pos < second_forloop_pos, "Function must come before second for-loop" + + +def test_add_global_assignments_variable_depends_on_existing_global() -> None: + """Test that new global assignments depending on existing globals are inserted after them. + + This tests the fix for a bug where LLM-generated optimizations that add module-level + cache variables like `_TRANSLATION_CACHE: dict = {None: tbl}` would fail because the + assignment was inserted after imports but BEFORE the `tbl` variable it depends on. + + Real-world example from unstructured/cleaners/core.py optimization: + - Original has: `tbl = dict.fromkeys(...)` + - Optimized adds: `_TRANSLATION_CACHE: dict = {None: tbl}` (depends on tbl) + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original file with imports and module-level variable + original_code = '''from __future__ import annotations +import sys +import unicodedata + +tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") +) + +def remove_sentence_punctuation(s: str) -> str: + tbl_new = tbl.copy() + return s.translate(tbl_new) +''' + + # Optimized code adds a cache that depends on `tbl` + optimized_code = '''from __future__ import annotations +import sys +import unicodedata + +tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") +) + +# Cache for translation tables +_TRANSLATION_CACHE: dict = {None: tbl} + +def remove_sentence_punctuation(s: str) -> str: + tbl_new = tbl.copy() + return s.translate(tbl_new) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError: {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify `_TRANSLATION_CACHE` is in the result + assert "_TRANSLATION_CACHE" in result, "_TRANSLATION_CACHE should be in the result" + + # Verify correct ordering: `tbl` must come before `_TRANSLATION_CACHE` + tbl_pos = result.index("tbl = dict.fromkeys") + cache_pos = result.index("_TRANSLATION_CACHE") + + assert cache_pos > tbl_pos, ( + f"_TRANSLATION_CACHE (pos {cache_pos}) should be after " + f"tbl (pos {tbl_pos}) because it depends on it" + ) + + +def test_add_global_assignments_tuple_unpacking() -> None: + """Test that tuple unpacking assignments are properly tracked. + + Verifies the fix for: a, b = 1, 2 where the target is a Tuple, not a Name. + Without the fix, neither 'a' nor 'b' would be tracked as defined. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with tuple unpacking that a later variable depends on + optimized_code = '''import sys + +a, b = 1, 2 +c = a + b + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Tuple unpacking test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: a, b = 1, 2 must come before c = a + b + unpack_pos = result.index("a, b = 1, 2") + c_pos = result.index("c = a + b") + assert unpack_pos < c_pos, "Tuple unpacking must come before dependent assignment" + + +def test_add_global_assignments_chained_assignment() -> None: + """Test that chained assignments are properly tracked. + + Verifies the fix for: a = b = c = 5 which has multiple targets. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with chained assignment that a later variable depends on + optimized_code = '''import sys + +a = b = c = 5 +d = a + b + c + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Chained assignment test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: a = b = c = 5 must come before d = a + b + c + chain_pos = result.index("a = b = c = 5") + d_pos = result.index("d = a + b + c") + assert chain_pos < d_pos, "Chained assignment must come before dependent assignment" + + +def test_add_global_assignments_multiple_new_statements() -> None: + """Test that multiple new statements maintain correct order. + + When inserting multiple statements with no dependencies, they should + maintain their relative order from the optimized code. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with multiple independent global assignments + optimized_code = '''import sys + +FIRST = 1 +SECOND = 2 +THIRD = 3 + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except (NameError, SyntaxError) as e: + msg = f"Multiple statements test failed: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: FIRST, SECOND, THIRD in order + first_pos = result.index("FIRST = 1") + second_pos = result.index("SECOND = 2") + third_pos = result.index("THIRD = 3") + assert first_pos < second_pos < third_pos, ( + f"Statements should maintain order: FIRST ({first_pos}) < SECOND ({second_pos}) < THIRD ({third_pos})" + ) + + +def test_add_global_assignments_annotated_no_spurious_deps() -> None: + """Test that type annotations don't create spurious dependencies. + + Verifies the fix for: x: Tuple[int, int] = value + Without the fix, Tuple and int would be added as spurious dependencies. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +val = (1, 2) + +def foo(): + pass +''' + + # Optimized code with type annotation - the annotation uses Tuple[int, int] + # but the actual value only depends on 'val' + optimized_code = '''import sys + +val = (1, 2) +x: tuple[int, int] = val +y = x + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Annotated assignment test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify x assignment is in the result + assert "x: tuple[int, int] = val" in result, "Annotated assignment should be present" + + # Verify correct ordering: val must come before x, x must come before y + val_pos = result.index("val = (1, 2)") + x_pos = result.index("x: tuple[int, int] = val") + y_pos = result.index("y = x") + assert val_pos < x_pos < y_pos, "Variables should be in dependency order" + + +# ============================================================================= +# Real-world standardize_quotes optimization tests +# These tests verify the fixes work for the actual optimization scenarios +# ============================================================================= + + +def test_standardize_quotes_optimization_candidate_6_pattern() -> None: + """Test optimization pattern from candidate 6 that caused NameError: name '_double_chars' is not defined. + + This pattern uses tuple comprehensions at module level that depend on unicode_to_char, + then uses those tuples in for-loops to build translation tables. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original standardize_quotes code structure + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + double_quotes = { + '"': "U+0022", + '"': "U+201C", + '"': "U+201D", + } + single_quotes = { + "'": "U+0027", + "'": "U+2018", + "'": "U+2019", + } + + double_quote_standard = '"' + single_quote_standard = "'" + + for unicode_val in double_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, double_quote_standard) + + for unicode_val in single_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, single_quote_standard) + + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 6 pattern: precompute chars using tuple comprehension + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = { + '"': "U+0022", + '"': "U+201C", + '"': "U+201D", +} +single_quotes = { + "'": "U+0027", + "'": "U+2018", + "'": "U+2019", +} + +_double_chars = tuple(unicode_to_char(u) for u in double_quotes.values()) +_single_chars = tuple(unicode_to_char(u) for u in single_quotes.values()) + +_QUOTE_TRANSLATION: dict[int, str] = {} +_double_quote_standard = '"' +_single_quote_standard = "'" + +for _ch in _double_chars: + _QUOTE_TRANSLATION[ord(_ch)] = _double_quote_standard + +for _ch in _single_chars: + _QUOTE_TRANSLATION[ord(_ch)] = _single_quote_standard + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + raise AssertionError( + f"Candidate 6 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify key elements are present + assert "_double_chars = tuple" in result, "_double_chars assignment should be present" + assert "for _ch in _double_chars" in result, "First for-loop should be present" + assert "for _ch in _single_chars" in result, "Second for-loop should be present" + + # Verify ordering: unicode_to_char must come before _double_chars assignment + func_pos = result.index("def unicode_to_char") + double_chars_pos = result.index("_double_chars = tuple") + assert func_pos < double_chars_pos, "unicode_to_char must be defined before _double_chars" + + +def test_standardize_quotes_optimization_candidate_7_pattern() -> None: + """Test optimization pattern from candidate 7 that caused NameError: name 'unicode_to_char' is not defined. + + This pattern moves quote dictionaries to module level and builds translation + table using for-loops that call unicode_to_char. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 7 pattern: module-level for-loops calling unicode_to_char + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_UNICODE_VALUES = [ + "U+0022", + "U+201C", + "U+201D", +] + +_SINGLE_QUOTE_UNICODE_VALUES = [ + "U+0027", + "U+2018", + "U+2019", +] + +_QUOTE_TRANSLATION_TABLE = {} +_double_replacement_ord = ord('"') +_single_replacement_ord = ord("'") + +for u in _DOUBLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(u) + _QUOTE_TRANSLATION_TABLE[ord(ch)] = _double_replacement_ord + +for u in _SINGLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(u) + _QUOTE_TRANSLATION_TABLE[ord(ch)] = _single_replacement_ord + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is correctly populated by the for-loops + # This verifies that the for-loops execute correctly and call unicode_to_char + translation_table = namespace.get("_QUOTE_TRANSLATION_TABLE") + assert translation_table is not None, "_QUOTE_TRANSLATION_TABLE should be defined" + # 8220 = U+201C (left double quote), 8221 = U+201D (right double quote) + # 34 = ord('"') = ASCII double quote + assert 8220 in translation_table, "Left double quote (U+201C) should be in table" + assert 8221 in translation_table, "Right double quote (U+201D) should be in table" + assert translation_table[8220] == 34, "Left double quote should map to ASCII quote" + assert translation_table[8221] == 34, "Right double quote should map to ASCII quote" + + except NameError as e: + raise AssertionError( + f"Candidate 7 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify for-loops are present and ordered correctly + assert "for u in _DOUBLE_QUOTE_UNICODE_VALUES" in result + assert "for u in _SINGLE_QUOTE_UNICODE_VALUES" in result + + func_pos = result.index("def unicode_to_char") + first_loop_pos = result.index("for u in _DOUBLE_QUOTE_UNICODE_VALUES") + assert func_pos < first_loop_pos, "unicode_to_char must be defined before for-loops" + + +def test_standardize_quotes_optimization_candidate_9_pattern() -> None: + """Test optimization pattern from candidate 9 that caused NameError: name 'unicode_to_char' is not defined. + + This pattern is similar to candidate 7 but with slightly different variable names. + It builds a translation table at module level using for-loops. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code with unicode_to_char defined AFTER standardize_quotes + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + double_quotes = {'"': "U+0022", '"': "U+201C"} + single_quotes = {"'": "U+0027", "'": "U+2018"} + + for unicode_val in double_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, '"') + + for unicode_val in single_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, "'") + + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 9 pattern + # Use proper Unicode escape sequences for the curly quote characters as keys + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {'"': "U+0022", '\u201c': "U+201C", '\u201d': "U+201D"} +single_quotes = {"'": "U+0027", '\u2018': "U+2018", '\u2019': "U+2019"} + +_translation_table = {} + +_double_standard_ord = ord('"') +for unicode_val in double_quotes.values(): + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = _double_standard_ord + +_single_standard_ord = ord("'") +for unicode_val in single_quotes.values(): + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = _single_standard_ord + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_translation_table) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is correctly populated by the for-loops + # This verifies that the for-loops execute correctly and call unicode_to_char + translation_table = namespace.get("_translation_table") + assert translation_table is not None, "_translation_table should be defined" + # 8220 = U+201C (left double quote), 8221 = U+201D (right double quote) + # 8216 = U+2018 (left single quote), 8217 = U+2019 (right single quote) + assert 8220 in translation_table, "Left double quote (U+201C) should be in table" + assert 8221 in translation_table, "Right double quote (U+201D) should be in table" + assert 8216 in translation_table, "Left single quote (U+2018) should be in table" + assert 8217 in translation_table, "Right single quote (U+2019) should be in table" + # Verify mappings to ASCII quotes + assert translation_table[8220] == 34, "Left double quote should map to ASCII double quote" + assert translation_table[8221] == 34, "Right double quote should map to ASCII double quote" + assert translation_table[8216] == 39, "Left single quote should map to ASCII single quote" + assert translation_table[8217] == 39, "Right single quote should map to ASCII single quote" + + except NameError as e: + raise AssertionError( + f"Candidate 9 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify ordering: unicode_to_char must come before module-level for-loops + # Use \nfor to find module-level (unindented) for-loops, not those inside functions + func_pos = result.index("def unicode_to_char") + first_loop = result.index("\nfor unicode_val in double_quotes.values()") + second_loop = result.index("\nfor unicode_val in single_quotes.values()") + + assert func_pos < first_loop, "unicode_to_char must be defined before first for-loop" + assert func_pos < second_loop, "unicode_to_char must be defined before second for-loop" + assert first_loop < second_loop, "For-loops should maintain their relative order" + + +def test_standardize_quotes_testgen_postprocessing_with_translation_table() -> None: + """Test end-to-end postprocessing pipeline for standardize_quotes with translation table pattern. + + This simulates the testgen endpoint returning generated tests for an optimization + that uses module-level for-loops to build a translation table, and verifies + the postprocessing pipeline (add_global_assignments + test postprocessing) works correctly. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original code (what we're optimizing) + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Simulated testgen optimization response with module-level for-loops + # This pattern builds a translation table at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_CODES = ["U+0022", "U+201C", "U+201D"] +_SINGLE_QUOTE_CODES = ["U+0027", "U+2018", "U+2019"] + +_QUOTE_TABLE = {} + +for _code in _DOUBLE_QUOTE_CODES: + _ch = unicode_to_char(_code) + _QUOTE_TABLE[ord(_ch)] = ord('"') + +for _code in _SINGLE_QUOTE_CODES: + _ch = unicode_to_char(_code) + _QUOTE_TABLE[ord(_ch)] = ord("'") + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TABLE) +''' + + # Simulated generated test code + generated_test_source = '''import pytest +from module import standardize_quotes + +def test_standardize_quotes_basic(): + """Test basic quote standardization.""" + result = standardize_quotes("Hello world") + assert result == "Hello world" + +def test_standardize_quotes_unicode_double(): + """Test unicode double quote conversion.""" + result = standardize_quotes("Say \\u201chello\\u201d") + assert '"' in result + +def test_standardize_quotes_empty(): + """Test empty string.""" + result = standardize_quotes("") + assert result == "" +''' + + # Step 1: Process optimization code through add_global_assignments + processed_optimization = add_global_assignments(optimized_code, original_code) + + # Verify optimization code compiles and executes without NameError + try: + compiled = compile(processed_optimization, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is populated + quote_table = namespace.get("_QUOTE_TABLE") + assert quote_table is not None, "_QUOTE_TABLE should be defined" + assert 8220 in quote_table, "Left double quote (U+201C) should be in table" + assert 8221 in quote_table, "Right double quote (U+201D) should be in table" + except NameError as e: + raise AssertionError(f"Optimization code failed with NameError: {e}\n\n{processed_optimization}") from e + + # Step 2: Process generated tests through remove_functions_from_generated_tests + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_standardize_quotes.py"), + perf_file_path=Path("test_standardize_quotes_perf.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Simulate removing a failed test + processed_tests = remove_functions_from_generated_tests( + generated_tests_list, ["test_standardize_quotes_empty"] + ) + + # Verify the test was removed and others remain + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_quotes_basic" in result_source + assert "test_standardize_quotes_unicode_double" in result_source + assert "test_standardize_quotes_empty" not in result_source + + +def test_standardize_quotes_testgen_postprocessing_with_dict_comprehension() -> None: + """Test postprocessing pipeline for standardize_quotes with dict comprehension pattern. + + This simulates an optimization that uses dict comprehension with calls to + a helper function, ensuring the global assignments are correctly ordered. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization using dict comprehension (depends on unicode_to_char being defined first) + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_UNICODES = {"U+0022": '"', "U+201C": '"', "U+201D": '"'} +_SINGLE_UNICODES = {"U+0027": "'", "U+2018": "'", "U+2019": "'"} + +# Build translation table using comprehension +_TRANSLATION = { + ord(unicode_to_char(code)): ord(target) + for mapping in [_DOUBLE_UNICODES, _SINGLE_UNICODES] + for code, target in mapping.items() +} + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_TRANSLATION) +''' + + # Process optimization code + processed_optimization = add_global_assignments(optimized_code, original_code) + + # Verify code compiles and translation dict is built + try: + compiled = compile(processed_optimization, "", "exec") + namespace = {} + exec(compiled, namespace) + + translation = namespace.get("_TRANSLATION") + assert translation is not None, "_TRANSLATION should be defined" + # Verify the mapping contains unicode quote codes + assert len(translation) >= 4, "Translation table should have at least 4 entries" + except NameError as e: + raise AssertionError(f"Dict comprehension pattern failed with NameError: {e}") from e + + # Create mock generated tests with runtime data + generated_test_source = '''def test_standardize_double_quotes(): + result = standardize_quotes("\\u201chello\\u201d") + assert result == '"hello"' + +def test_standardize_single_quotes(): + result = standardize_quotes("\\u2018world\\u2019") + assert result == "'world'" +''' + + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_quotes.py"), + perf_file_path=Path("test_quotes_perf.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Mock runtime data for add_runtime_comments_to_generated_tests + # (Empty dicts since we don't have actual runtime data in this test) + original_runtimes: dict = {} + optimized_runtimes: dict = {} + + # Process through runtime comments (should handle empty runtimes gracefully) + processed_tests = add_runtime_comments_to_generated_tests( + generated_tests_list, original_runtimes, optimized_runtimes + ) + + # Verify tests are still valid + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_double_quotes" in result_source + assert "test_standardize_single_quotes" in result_source + + +def test_standardize_quotes_testgen_full_pipeline_integration() -> None: + """Test complete integration of testgen postprocessing pipeline for standardize_quotes. + + This test simulates the full flow: + 1. Original code with the function to optimize + 2. LLM-generated optimization with module-level for-loops + 3. Processing through add_global_assignments + 4. Generated tests processing through remove_functions + add_runtime_comments + 5. Final verification that everything compiles and is correct + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import ( + add_runtime_comments_to_generated_tests, + remove_functions_from_generated_tests, + ) + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original FTO code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes. + + Args: + text: Input text that may contain unicode quotes. + + Returns: + Text with unicode quotes replaced by ASCII quotes. + """ + double_quotes = {'"': "U+0022", '\u201c': "U+201C", '\u201d': "U+201D"} + single_quotes = {"'": "U+0027", '\u2018': "U+2018", '\u2019': "U+2019"} + + for char, unicode_val in double_quotes.items(): + if char != '"': + text = text.replace(char, '"') + + for char, unicode_val in single_quotes.items(): + if char != "'": + text = text.replace(char, "'") + + return text +''' + + # LLM-generated optimization with a NEW helper function and module-level call + # This tests that add_global_assignments transfers new function definitions + optimized_code = '''def _build_translation_table() -> dict[int, int]: + """Build translation table for quote standardization.""" + table = {} + # Double quotes + for code in [0x0022, 0x201C, 0x201D]: + table[code] = 0x0022 # Map to ASCII double quote + # Single quotes + for code in [0x0027, 0x2018, 0x2019]: + table[code] = 0x0027 # Map to ASCII single quote + return table + +_QUOTE_TRANSLATION_TABLE = _build_translation_table() + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION_TABLE) +''' + + # Step 1: Process optimization through add_global_assignments + processed_code = add_global_assignments(optimized_code, original_code) + + # Verify optimization code structure - new function should be transferred + assert "_build_translation_table" in processed_code, "New helper function should be present" + assert "_QUOTE_TRANSLATION_TABLE = _build_translation_table()" in processed_code + + # Verify code executes without errors + namespace = {} + exec(compile(processed_code, "", "exec"), namespace) + + # Verify the translation table is correctly built + table = namespace["_QUOTE_TRANSLATION_TABLE"] + assert table[0x201C] == 0x0022, "Left double quote should map to ASCII double quote" + assert table[0x201D] == 0x0022, "Right double quote should map to ASCII double quote" + assert table[0x2018] == 0x0027, "Left single quote should map to ASCII single quote" + assert table[0x2019] == 0x0027, "Right single quote should map to ASCII single quote" + + # Step 2: Create mock generated tests + generated_test_source = '''import pytest + +def test_standardize_quotes_no_change(): + """Test text without unicode quotes.""" + result = standardize_quotes("Hello, World!") + assert result == "Hello, World!" + +def test_standardize_quotes_double_unicode(): + """Test double quote unicode conversion.""" + text = "She said \\u201cHello\\u201d" + result = standardize_quotes(text) + assert result == 'She said "Hello"' + +def test_standardize_quotes_single_unicode(): + """Test single quote unicode conversion.""" + text = "It\\u2019s a test" + result = standardize_quotes(text) + assert result == "It's a test" + +def test_standardize_quotes_mixed(): + """Test mixed quote types.""" + text = "\\u201cIt\\u2019s \\u2018great\\u2019\\u201d" + result = standardize_quotes(text) + assert result == '"It\\'s \\'great\\'\"' + +def test_standardize_quotes_failing(): + """This test will be removed as failed.""" + result = standardize_quotes("test") + assert result == "wrong" +''' + + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="instrumented_behavior_placeholder", + instrumented_perf_test_source="instrumented_perf_placeholder", + behavior_file_path=Path("tests/test_standardize_quotes__unit_test_0.py"), + perf_file_path=Path("tests/test_standardize_quotes__perf_test_0.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Step 3: Remove failed tests + tests_to_remove = ["test_standardize_quotes_failing"] + processed_tests = remove_functions_from_generated_tests(generated_tests_list, tests_to_remove) + + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_quotes_no_change" in result_source + assert "test_standardize_quotes_double_unicode" in result_source + assert "test_standardize_quotes_single_unicode" in result_source + assert "test_standardize_quotes_mixed" in result_source + assert "test_standardize_quotes_failing" not in result_source + + # Step 4: Add runtime comments (with empty data to test graceful handling) + final_tests = add_runtime_comments_to_generated_tests(processed_tests, {}, {}) + + # Verify final output is still valid Python + final_source = final_tests.generated_tests[0].generated_original_test_source + compile(final_source, "", "exec") # Should not raise + + # Count remaining test functions + test_count = final_source.count("def test_") + assert test_count == 4, f"Should have 4 tests after removing failed one, got {test_count}" + + +def test_match_statement_pattern_names(): + """Test that match statement patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + # Test basic match with MatchAs + code = """ +match value: + case [x, y]: + result = x + y + case {"name": name}: + other = name +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + # Should find x, y from first case, name from second case, result, other from bodies + assert "x" in defined_names, "Pattern variable 'x' should be defined" + assert "y" in defined_names, "Pattern variable 'y' should be defined" + assert "name" in defined_names, "Pattern variable 'name' should be defined" + assert "result" in defined_names, "Body variable 'result' should be defined" + assert "other" in defined_names, "Body variable 'other' should be defined" + + +def test_match_statement_star_pattern(): + """Test that match statement star patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match items: + case [first, *rest]: + total = sum(rest) +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "first" in defined_names, "Pattern variable 'first' should be defined" + assert "rest" in defined_names, "Star pattern variable 'rest' should be defined" + assert "total" in defined_names, "Body variable 'total' should be defined" + + +def test_match_statement_as_pattern(): + """Test that match statement 'as' patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match obj: + case [a, b] as both: + use = both +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "a" in defined_names, "Pattern variable 'a' should be defined" + assert "b" in defined_names, "Pattern variable 'b' should be defined" + assert "both" in defined_names, "'as' capture variable 'both' should be defined" + + +def test_match_statement_mapping_rest(): + """Test that match statement mapping rest patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match config: + case {"known": val, **rest}: + extra = rest +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "val" in defined_names, "Mapping value 'val' should be defined" + assert "rest" in defined_names, "Mapping rest '**rest' should be defined" + + +def test_comprehension_variable_not_dependency(): + """Test that comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "result = [x * 2 for x in items]" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Comprehension variable 'x' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_nested_comprehension_variables_not_dependencies(): + """Test that nested comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "matrix = [[row[col] for col in range(n)] for row in data]" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "row" not in deps, "Outer comprehension variable 'row' should not be a dependency" + assert "col" not in deps, "Inner comprehension variable 'col' should not be a dependency" + assert "data" in deps, "Outer iterable 'data' should be a dependency" + assert "n" in deps, "Inner range argument 'n' should be a dependency" + assert "range" in deps, "Built-in 'range' should be a dependency" + + +def test_dict_comprehension_variables_not_dependencies(): + """Test that dict comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "mapping = {k: v * 2 for k, v in items.items()}" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "k" not in deps, "Dict comprehension key variable 'k' should not be a dependency" + assert "v" not in deps, "Dict comprehension value variable 'v' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_generator_expression_variable_not_dependency(): + """Test that generator expression variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "gen = (x for x in items if x > 0)" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Generator expression variable 'x' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_lambda_parameter_not_dependency(): + """Test that lambda parameters aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "handler = lambda x, y: x + y + z" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Lambda parameter 'x' should not be a dependency" + assert "y" not in deps, "Lambda parameter 'y' should not be a dependency" + assert "z" in deps, "Free variable 'z' should be a dependency" + + +def test_lambda_star_args_not_dependency(): + """Test that lambda *args and **kwargs aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "handler = lambda *args, **kwargs: process(args, kwargs, config)" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "args" not in deps, "Lambda *args should not be a dependency" + assert "kwargs" not in deps, "Lambda **kwargs should not be a dependency" + assert "process" in deps, "Free variable 'process' should be a dependency" + assert "config" in deps, "Free variable 'config' should be a dependency" + + +def test_circular_dependency_warning(caplog): + """Test that circular dependencies produce a warning.""" + import logging + + from codeflash.code_utils.code_extractor import _sort_statements_by_dependencies + + code = """ +x = y + 1 +y = x + 1 +""" + module = cst.parse_module(code) + + with caplog.at_level(logging.WARNING): + result = _sort_statements_by_dependencies(list(module.body)) + + # Should emit a warning about circular dependency + assert any("Circular dependency detected" in record.message for record in caplog.records), ( + "Should log a warning about circular dependencies" + ) + # Should return original order when cycle is detected + assert len(result) == 2, "Should return all statements" + + +def test_add_global_assignments_with_match_statement(): + """Test that match statements in optimized code are handled correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +match config: + case {"type": t}: + result = t + case _: + result = "default" + +use_result = result + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code contains the match statement + assert "match config" in result, "Match statement should be in result" + assert "use_result = result" in result, "Variable using match result should be in result" + + # The code should be syntactically valid (but may have runtime NameErrors + # due to undefined 'config') + compiled = compile(result, "", "exec") + assert compiled is not None + + +def test_comprehension_in_global_assignment(): + """Test that global assignments with comprehensions work correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +items = [1, 2, 3, 4, 5] +doubled = [x * 2 for x in items] +filtered = [x for x in doubled if x > 4] + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + assert namespace["doubled"] == [2, 4, 6, 8, 10] + assert namespace["filtered"] == [6, 8, 10] + except NameError as e: + msg = f"Comprehension test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + +def test_lambda_in_global_assignment(): + """Test that global assignments with lambdas work correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +multiplier = 2 +scale = lambda x: x * multiplier + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + assert namespace["scale"](5) == 10 + except NameError as e: + msg = f"Lambda test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e diff --git a/uv.lock b/uv.lock index 7012850eb..e055ccb05 100644 --- a/uv.lock +++ b/uv.lock @@ -433,6 +433,7 @@ dependencies = [ { name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pytest-timeout" }, + { name = "pyyaml" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "tomlkit" }, @@ -451,6 +452,7 @@ dev = [ { name = "pre-commit", version = "4.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pre-commit", version = "4.5.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "ruff" }, + { name = "ty" }, { name = "types-cffi" }, { name = "types-colorama" }, { name = "types-decorator" }, @@ -517,6 +519,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, + { name = "pyyaml", specifier = ">=6.0.3" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "tomlkit", specifier = ">=0.11.7" }, @@ -531,6 +534,7 @@ dev = [ { name = "pandas-stubs", specifier = ">=2.2.2.240807,<2.2.3.241009" }, { name = "pre-commit", specifier = ">=4.2.0,<5" }, { name = "ruff", specifier = ">=0.7.0" }, + { name = "ty", specifier = ">=0.0.12" }, { name = "types-cffi", specifier = ">=1.16.0.20240331" }, { name = "types-colorama", specifier = ">=0.4.15.20240311" }, { name = "types-decorator", specifier = ">=5.1.8.20240310" }, @@ -5154,6 +5158,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" }, ] +[[package]] +name = "ty" +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/78/ba1a4ad403c748fbba8be63b7e774a90e80b67192f6443d624c64fe4aaab/ty-0.0.12.tar.gz", hash = "sha256:cd01810e106c3b652a01b8f784dd21741de9fdc47bd595d02c122a7d5cefeee7", size = 4981303, upload-time = "2026-01-14T22:30:48.537Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8f/c21314d074dda5fb13d3300fa6733fd0d8ff23ea83a721818740665b6314/ty-0.0.12-py3-none-linux_armv6l.whl", hash = "sha256:eb9da1e2c68bd754e090eab39ed65edf95168d36cbeb43ff2bd9f86b4edd56d1", size = 9614164, upload-time = "2026-01-14T22:30:44.016Z" }, + { url = "https://files.pythonhosted.org/packages/09/28/f8a4d944d13519d70c486e8f96d6fa95647ac2aa94432e97d5cfec1f42f6/ty-0.0.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c181f42aa19b0ed7f1b0c2d559980b1f1d77cc09419f51c8321c7ddf67758853", size = 9542337, upload-time = "2026-01-14T22:30:05.687Z" }, + { url = "https://files.pythonhosted.org/packages/e1/9c/f576e360441de7a8201daa6dc4ebc362853bc5305e059cceeb02ebdd9a48/ty-0.0.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1f829e1eecd39c3e1b032149db7ae6a3284f72fc36b42436e65243a9ed1173db", size = 8909582, upload-time = "2026-01-14T22:30:46.089Z" }, + { url = "https://files.pythonhosted.org/packages/d6/13/0898e494032a5d8af3060733d12929e3e7716db6c75eac63fa125730a3e7/ty-0.0.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45162e7826e1789cf3374627883cdeb0d56b82473a0771923e4572928e90be3", size = 9384932, upload-time = "2026-01-14T22:30:13.769Z" }, + { url = "https://files.pythonhosted.org/packages/e4/1a/b35b6c697008a11d4cedfd34d9672db2f0a0621ec80ece109e13fca4dfef/ty-0.0.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d11fec40b269bec01e751b2337d1c7ffa959a2c2090a950d7e21c2792442cccd", size = 9453140, upload-time = "2026-01-14T22:30:11.131Z" }, + { url = "https://files.pythonhosted.org/packages/dd/1e/71c9edbc79a3c88a0711324458f29c7dbf6c23452c6e760dc25725483064/ty-0.0.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09d99e37e761a4d2651ad9d5a610d11235fbcbf35dc6d4bc04abf54e7cf894f1", size = 9960680, upload-time = "2026-01-14T22:30:33.621Z" }, + { url = "https://files.pythonhosted.org/packages/0e/75/39375129f62dd22f6ad5a99cd2a42fd27d8b91b235ce2db86875cdad397d/ty-0.0.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d9ca0cdb17bd37397da7b16a7cd23423fc65c3f9691e453ad46c723d121225a1", size = 10904518, upload-time = "2026-01-14T22:30:08.464Z" }, + { url = "https://files.pythonhosted.org/packages/32/5e/26c6d88fafa11a9d31ca9f4d12989f57782ec61e7291d4802d685b5be118/ty-0.0.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcf2757b905e7eddb7e456140066335b18eb68b634a9f72d6f54a427ab042c64", size = 10525001, upload-time = "2026-01-14T22:30:16.454Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a5/2f0b91894af13187110f9ad7ee926d86e4e6efa755c9c88a820ed7f84c85/ty-0.0.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00cf34c1ebe1147efeda3021a1064baa222c18cdac114b7b050bbe42deb4ca80", size = 10307103, upload-time = "2026-01-14T22:30:41.221Z" }, + { url = "https://files.pythonhosted.org/packages/4b/77/13d0410827e4bc713ebb7fdaf6b3590b37dcb1b82e0a81717b65548f2442/ty-0.0.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb3a655bd869352e9a22938d707631ac9fbca1016242b1f6d132d78f347c851", size = 10072737, upload-time = "2026-01-14T22:30:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/e1/dd/fc36d8bac806c74cf04b4ca735bca14d19967ca84d88f31e121767880df1/ty-0.0.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4658e282c7cb82be304052f8f64f9925f23c3c4f90eeeb32663c74c4b095d7ba", size = 9368726, upload-time = "2026-01-14T22:30:18.683Z" }, + { url = "https://files.pythonhosted.org/packages/54/70/9e8e461647550f83e2fe54bc632ccbdc17a4909644783cdbdd17f7296059/ty-0.0.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c167d838eaaa06e03bb66a517f75296b643d950fbd93c1d1686a187e5a8dbd1f", size = 9454704, upload-time = "2026-01-14T22:30:22.759Z" }, + { url = "https://files.pythonhosted.org/packages/04/9b/6292cf7c14a0efeca0539cf7d78f453beff0475cb039fbea0eb5d07d343d/ty-0.0.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2956e0c9ab7023533b461d8a0e6b2ea7b78e01a8dde0688e8234d0fce10c4c1c", size = 9649829, upload-time = "2026-01-14T22:30:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/472a5d2013371e4870886cff791c94abdf0b92d43d305dd0f8e06b6ff719/ty-0.0.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c6a3fd7479580009f21002f3828320621d8a82d53b7ba36993234e3ccad58c8", size = 10162814, upload-time = "2026-01-14T22:30:36.174Z" }, + { url = "https://files.pythonhosted.org/packages/31/e9/2ecbe56826759845a7c21d80aa28187865ea62bc9757b056f6cbc06f78ed/ty-0.0.12-py3-none-win32.whl", hash = "sha256:a91c24fd75c0f1796d8ede9083e2c0ec96f106dbda73a09fe3135e075d31f742", size = 9140115, upload-time = "2026-01-14T22:30:38.903Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6d/d9531eff35a5c0ec9dbc10231fac21f9dd6504814048e81d6ce1c84dc566/ty-0.0.12-py3-none-win_amd64.whl", hash = "sha256:df151894be55c22d47068b0f3b484aff9e638761e2267e115d515fcc9c5b4a4b", size = 9884532, upload-time = "2026-01-14T22:30:25.112Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f3/20b49e75967023b123a221134548ad7000f9429f13fdcdda115b4c26305f/ty-0.0.12-py3-none-win_arm64.whl", hash = "sha256:cea99d334b05629de937ce52f43278acf155d3a316ad6a35356635f886be20ea", size = 9313974, upload-time = "2026-01-14T22:30:27.44Z" }, +] + [[package]] name = "types-cffi" version = "1.17.0.20250915"