From c64325ccebd24d8a4b26d793a1fc1f8042aa837a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 18 Jan 2026 21:19:38 -0500 Subject: [PATCH 01/14] add ty to dev dependencies --- pyproject.toml | 1 + uv.lock | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1714532d0..bc9dc12db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ dev = [ "types-unidiff>=0.7.0.20240505,<0.8", "uv>=0.6.2", "pre-commit>=4.2.0,<5", + "ty>=0.0.12", ] tests = [ "black>=25.9.0", diff --git a/uv.lock b/uv.lock index 411a854ff..3012a779a 100644 --- a/uv.lock +++ b/uv.lock @@ -451,6 +451,7 @@ dev = [ { name = "pre-commit", version = "4.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pre-commit", version = "4.5.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "ruff" }, + { name = "ty" }, { name = "types-cffi" }, { name = "types-colorama" }, { name = "types-decorator" }, @@ -531,6 +532,7 @@ dev = [ { name = "pandas-stubs", specifier = ">=2.2.2.240807,<2.2.3.241009" }, { name = "pre-commit", specifier = ">=4.2.0,<5" }, { name = "ruff", specifier = ">=0.7.0" }, + { name = "ty", specifier = ">=0.0.12" }, { name = "types-cffi", specifier = ">=1.16.0.20240331" }, { name = "types-colorama", specifier = ">=0.4.15.20240311" }, { name = "types-decorator", specifier = ">=5.1.8.20240310" }, @@ -925,7 +927,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -5154,6 +5156,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" }, ] +[[package]] +name = "ty" +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/78/ba1a4ad403c748fbba8be63b7e774a90e80b67192f6443d624c64fe4aaab/ty-0.0.12.tar.gz", hash = "sha256:cd01810e106c3b652a01b8f784dd21741de9fdc47bd595d02c122a7d5cefeee7", size = 4981303, upload-time = "2026-01-14T22:30:48.537Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8f/c21314d074dda5fb13d3300fa6733fd0d8ff23ea83a721818740665b6314/ty-0.0.12-py3-none-linux_armv6l.whl", hash = "sha256:eb9da1e2c68bd754e090eab39ed65edf95168d36cbeb43ff2bd9f86b4edd56d1", size = 9614164, upload-time = "2026-01-14T22:30:44.016Z" }, + { url = "https://files.pythonhosted.org/packages/09/28/f8a4d944d13519d70c486e8f96d6fa95647ac2aa94432e97d5cfec1f42f6/ty-0.0.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c181f42aa19b0ed7f1b0c2d559980b1f1d77cc09419f51c8321c7ddf67758853", size = 9542337, upload-time = "2026-01-14T22:30:05.687Z" }, + { url = "https://files.pythonhosted.org/packages/e1/9c/f576e360441de7a8201daa6dc4ebc362853bc5305e059cceeb02ebdd9a48/ty-0.0.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1f829e1eecd39c3e1b032149db7ae6a3284f72fc36b42436e65243a9ed1173db", size = 8909582, upload-time = "2026-01-14T22:30:46.089Z" }, + { url = "https://files.pythonhosted.org/packages/d6/13/0898e494032a5d8af3060733d12929e3e7716db6c75eac63fa125730a3e7/ty-0.0.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45162e7826e1789cf3374627883cdeb0d56b82473a0771923e4572928e90be3", size = 9384932, upload-time = "2026-01-14T22:30:13.769Z" }, + { url = "https://files.pythonhosted.org/packages/e4/1a/b35b6c697008a11d4cedfd34d9672db2f0a0621ec80ece109e13fca4dfef/ty-0.0.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d11fec40b269bec01e751b2337d1c7ffa959a2c2090a950d7e21c2792442cccd", size = 9453140, upload-time = "2026-01-14T22:30:11.131Z" }, + { url = "https://files.pythonhosted.org/packages/dd/1e/71c9edbc79a3c88a0711324458f29c7dbf6c23452c6e760dc25725483064/ty-0.0.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09d99e37e761a4d2651ad9d5a610d11235fbcbf35dc6d4bc04abf54e7cf894f1", size = 9960680, upload-time = "2026-01-14T22:30:33.621Z" }, + { url = "https://files.pythonhosted.org/packages/0e/75/39375129f62dd22f6ad5a99cd2a42fd27d8b91b235ce2db86875cdad397d/ty-0.0.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d9ca0cdb17bd37397da7b16a7cd23423fc65c3f9691e453ad46c723d121225a1", size = 10904518, upload-time = "2026-01-14T22:30:08.464Z" }, + { url = "https://files.pythonhosted.org/packages/32/5e/26c6d88fafa11a9d31ca9f4d12989f57782ec61e7291d4802d685b5be118/ty-0.0.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcf2757b905e7eddb7e456140066335b18eb68b634a9f72d6f54a427ab042c64", size = 10525001, upload-time = "2026-01-14T22:30:16.454Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a5/2f0b91894af13187110f9ad7ee926d86e4e6efa755c9c88a820ed7f84c85/ty-0.0.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00cf34c1ebe1147efeda3021a1064baa222c18cdac114b7b050bbe42deb4ca80", size = 10307103, upload-time = "2026-01-14T22:30:41.221Z" }, + { url = "https://files.pythonhosted.org/packages/4b/77/13d0410827e4bc713ebb7fdaf6b3590b37dcb1b82e0a81717b65548f2442/ty-0.0.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb3a655bd869352e9a22938d707631ac9fbca1016242b1f6d132d78f347c851", size = 10072737, upload-time = "2026-01-14T22:30:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/e1/dd/fc36d8bac806c74cf04b4ca735bca14d19967ca84d88f31e121767880df1/ty-0.0.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4658e282c7cb82be304052f8f64f9925f23c3c4f90eeeb32663c74c4b095d7ba", size = 9368726, upload-time = "2026-01-14T22:30:18.683Z" }, + { url = "https://files.pythonhosted.org/packages/54/70/9e8e461647550f83e2fe54bc632ccbdc17a4909644783cdbdd17f7296059/ty-0.0.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c167d838eaaa06e03bb66a517f75296b643d950fbd93c1d1686a187e5a8dbd1f", size = 9454704, upload-time = "2026-01-14T22:30:22.759Z" }, + { url = "https://files.pythonhosted.org/packages/04/9b/6292cf7c14a0efeca0539cf7d78f453beff0475cb039fbea0eb5d07d343d/ty-0.0.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2956e0c9ab7023533b461d8a0e6b2ea7b78e01a8dde0688e8234d0fce10c4c1c", size = 9649829, upload-time = "2026-01-14T22:30:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/472a5d2013371e4870886cff791c94abdf0b92d43d305dd0f8e06b6ff719/ty-0.0.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c6a3fd7479580009f21002f3828320621d8a82d53b7ba36993234e3ccad58c8", size = 10162814, upload-time = "2026-01-14T22:30:36.174Z" }, + { url = "https://files.pythonhosted.org/packages/31/e9/2ecbe56826759845a7c21d80aa28187865ea62bc9757b056f6cbc06f78ed/ty-0.0.12-py3-none-win32.whl", hash = "sha256:a91c24fd75c0f1796d8ede9083e2c0ec96f106dbda73a09fe3135e075d31f742", size = 9140115, upload-time = "2026-01-14T22:30:38.903Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6d/d9531eff35a5c0ec9dbc10231fac21f9dd6504814048e81d6ce1c84dc566/ty-0.0.12-py3-none-win_amd64.whl", hash = "sha256:df151894be55c22d47068b0f3b484aff9e638761e2267e115d515fcc9c5b4a4b", size = 9884532, upload-time = "2026-01-14T22:30:25.112Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f3/20b49e75967023b123a221134548ad7000f9429f13fdcdda115b4c26305f/ty-0.0.12-py3-none-win_arm64.whl", hash = "sha256:cea99d334b05629de937ce52f43278acf155d3a316ad6a35356635f886be20ea", size = 9313974, upload-time = "2026-01-14T22:30:27.44Z" }, +] + [[package]] name = "types-cffi" version = "1.17.0.20250915" From 2d6854a598d427008902d2d909883c04cda6177a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 18 Jan 2026 22:00:13 -0500 Subject: [PATCH 02/14] first working version --- codeflash/api/aiservice.py | 69 +++++++++ codeflash/cli_cmds/cli.py | 25 ++++ codeflash/models/models.py | 39 +++++ codeflash/optimization/function_optimizer.py | 147 +++++++++++++++++++ 4 files changed, 280 insertions(+) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 40d70e588..36a9b924a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -193,6 +193,75 @@ def optimize_python_code( # noqa: D417 console.rule() return [] + def augmented_optimize( # noqa: D417 + self, + source_code: str, + system_prompt: str, + user_prompt: str, + trace_id: str, + dependency_code: str | None = None, + n_candidates: int = 3, + ) -> list[OptimizedCandidate]: + """Optimize code with custom prompts via /ai/augmented-optimize endpoint. + + Parameters + ---------- + - source_code (str): The python code to optimize (markdown format). + - system_prompt (str): Custom system prompt for the LLM. + - user_prompt (str): Custom user prompt for the LLM. + - trace_id (str): Trace id of optimization run. + - dependency_code (str | None): Optional dependency code context. + - n_candidates (int): Number of candidates to generate (max 3). + + Returns + ------- + - list[OptimizedCandidate]: A list of Optimization Candidates. + + """ + logger.info("Generating augmented optimization candidates…") + console.rule() + start_time = time.perf_counter() + git_repo_owner, git_repo_name = safe_get_repo_owner_and_name() + + payload = { + "source_code": source_code, + "dependency_code": dependency_code, + "trace_id": trace_id, + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "n_candidates": min(n_candidates, 3), + "python_version": platform.python_version(), + "codeflash_version": codeflash_version, + "current_username": get_last_commit_author_if_pr_exists(None), + "repo_owner": git_repo_owner, + "repo_name": git_repo_name, + } + logger.debug(f"Sending augmented optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") + + try: + response = self.make_ai_service_request("/augmented-optimize", payload=payload, timeout=self.timeout) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating augmented optimization candidates: {e}") + ph("cli-augmented-optimize-error-caught", {"error": str(e)}) + console.rule() + return [] + + if response.status_code == 200: + optimizations_json = response.json()["optimizations"] + end_time = time.perf_counter() + logger.debug(f"!lsp|Generating augmented optimizations took {end_time - start_time:.2f} seconds.") + logger.info(f"!lsp|Received {len(optimizations_json)} augmented optimization candidates.") + console.rule() + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.AUGMENTED) + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating augmented optimization candidates: {response.status_code} - {error}") + ph("cli-augmented-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def get_jit_rewritten_code( # noqa: D417 self, source_code: str, trace_id: str ) -> list[OptimizedCandidate]: diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index df0ae57ce..db72ee9be 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -117,6 +117,22 @@ def parse_args() -> Namespace: parser.add_argument( "--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium" ) + parser.add_argument( + "--augmented", + action="store_true", + help="Enable augmented optimization mode for two-phase optimization workflow", + ) + parser.add_argument( + "--augmented-prompt-file", + type=str, + help="Path to YAML file with custom system_prompt and user_prompt for Phase 2", + ) + parser.add_argument( + "--augmented-output", + type=str, + default="codeflash_phase1_results.json", + help="Path to write Phase 1 results JSON (default: codeflash_phase1_results.json)", + ) args, unknown_args = parser.parse_known_args() sys.argv[:] = [sys.argv[0], *unknown_args] @@ -175,6 +191,15 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: "Async function optimization is now enabled by default." ) + if args.augmented_prompt_file and not args.augmented: + exit_with_message("--augmented-prompt-file requires --augmented flag", error_on_exit=True) + + if args.augmented_prompt_file: + prompt_file = Path(args.augmented_prompt_file) + if not prompt_file.exists(): + exit_with_message(f"Augmented prompt file {args.augmented_prompt_file} does not exist", error_on_exit=True) + args.augmented_prompt_file = prompt_file.resolve() + return args diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d850e3827..c812919cf 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -489,6 +489,7 @@ class OptimizedCandidateSource(str, Enum): REPAIR = "REPAIR" ADAPTIVE = "ADAPTIVE" JIT_REWRITE = "JIT_REWRITE" + AUGMENTED = "AUGMENTED" @dataclass(frozen=True) @@ -913,3 +914,41 @@ def __eq__(self, other: object) -> bool: return False sys.setrecursionlimit(original_recursion_limit) return True + + +class Phase1CandidateResult(BaseModel): + optimization_id: str + source_code: str + explanation: str + speedup_ratio: Optional[float] = None + runtime_ns: Optional[int] = None + is_correct: bool + line_profiler_results: Optional[str] = None + test_failures: Optional[list[str]] = None + test_diffs: Optional[list[dict]] = None + + +class Phase1FunctionResult(BaseModel): + function_name: str + trace_id: str + original_source_code: str + dependency_code: Optional[str] = None + original_runtime_ns: Optional[int] = None + original_line_profiler_results: Optional[str] = None + candidates: list[Phase1CandidateResult] + best_candidate_id: Optional[str] = None + best_speedup_ratio: Optional[float] = None + + +class Phase1Output(BaseModel): + codeflash_version: str + timestamp: str + python_version: str + functions: list[Phase1FunctionResult] + total_functions: int + successful_optimizations: int + + +class AugmentedPrompts(BaseModel): + system_prompt: str + user_prompt: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 761b8ea0c..e523a65c0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -4,16 +4,19 @@ import concurrent.futures import logging import os +import platform import queue import random import subprocess import uuid from collections import defaultdict +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Callable import libcst as cst import sentry_sdk +import yaml from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -81,6 +84,7 @@ AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest, AIServiceCodeRepairRequest, + AugmentedPrompts, BestOptimization, CandidateEvaluationContext, CodeOptimizationContext, @@ -92,6 +96,9 @@ OptimizedCandidateResult, OptimizedCandidateSource, OriginalCodeBaseline, + Phase1CandidateResult, + Phase1FunctionResult, + Phase1Output, TestFile, TestFiles, TestingMode, @@ -468,6 +475,100 @@ def __init__( self.adaptive_optimization_counter = 0 # track how many adaptive optimizations we did for each function self.is_numerical_code: bool | None = None + self.augmented_mode = getattr(args, "augmented", False) if args else False + self.augmented_prompt_file = getattr(args, "augmented_prompt_file", None) if args else None + self.augmented_output = getattr(args, "augmented_output", "codeflash_phase1_results.json") if args else None + self.augmented_prompts: AugmentedPrompts | None = None + self.phase1_candidate_results: list[Phase1CandidateResult] = [] + + def load_augmented_prompts(self) -> AugmentedPrompts | None: + if not self.augmented_prompt_file: + return None + prompt_path = Path(self.augmented_prompt_file) + if not prompt_path.exists(): + logger.error(f"Augmented prompt file not found: {prompt_path}") + return None + with prompt_path.open(encoding="utf-8") as f: + data = yaml.safe_load(f) + if not data or "system_prompt" not in data or "user_prompt" not in data: + logger.error("Augmented prompt file must contain 'system_prompt' and 'user_prompt' keys") + return None + return AugmentedPrompts(system_prompt=data["system_prompt"], user_prompt=data["user_prompt"]) + + def collect_phase1_candidate_result( + self, + candidate: OptimizedCandidate, + eval_ctx: CandidateEvaluationContext, + test_failures: list[str] | None = None, + test_diffs: list[dict] | None = None, + ) -> Phase1CandidateResult: + speedup = eval_ctx.get_speedup_ratio(candidate.optimization_id) + runtime = eval_ctx.get_optimized_runtime(candidate.optimization_id) + is_correct = eval_ctx.is_correct.get(candidate.optimization_id, False) + line_profiler = eval_ctx.optimized_line_profiler_results.get(candidate.optimization_id) + return Phase1CandidateResult( + optimization_id=candidate.optimization_id, + source_code=candidate.source_code.markdown, + explanation=candidate.explanation, + speedup_ratio=speedup, + runtime_ns=int(runtime) if runtime else None, + is_correct=is_correct, + line_profiler_results=line_profiler, + test_failures=test_failures, + test_diffs=test_diffs, + ) + + def get_phase1_function_result( + self, + code_context: CodeOptimizationContext, + original_runtime_ns: int | None, + original_line_profiler_results: str | None, + best_candidate_id: str | None, + best_speedup_ratio: float | None, + ) -> Phase1FunctionResult: + return Phase1FunctionResult( + function_name=self.function_to_optimize.function_name, + trace_id=self.function_trace_id, + original_source_code=code_context.read_writable_code.markdown, + dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None, + original_runtime_ns=original_runtime_ns, + original_line_profiler_results=original_line_profiler_results, + candidates=self.phase1_candidate_results, + best_candidate_id=best_candidate_id, + best_speedup_ratio=best_speedup_ratio, + ) + + def write_phase1_output(self, function_result: Phase1FunctionResult) -> None: + from codeflash.version import __version__ as codeflash_version + + output_path = Path(self.augmented_output) if self.augmented_output else Path("codeflash_phase1_results.json") + output = Phase1Output( + codeflash_version=codeflash_version, + timestamp=datetime.now(tz=timezone.utc).isoformat(), + python_version=platform.python_version(), + functions=[function_result], + total_functions=1, + successful_optimizations=1 if function_result.best_candidate_id else 0, + ) + with output_path.open("w", encoding="utf-8") as f: + f.write(output.model_dump_json(indent=2)) + logger.info(f"Phase 1 results written to {output_path}") + + def get_augmented_candidates(self, code_context: CodeOptimizationContext) -> list[OptimizedCandidate]: + if not self.augmented_prompts: + logger.error("Augmented prompts not loaded") + return [] + candidates = self.aiservice_client.augmented_optimize( + source_code=code_context.read_writable_code.markdown, + system_prompt=self.augmented_prompts.system_prompt, + user_prompt=self.augmented_prompts.user_prompt, + trace_id=self.function_trace_id, + dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None, + n_candidates=3, + ) + logger.info(f"Received {len(candidates)} augmented optimization candidates") + return candidates + def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}") @@ -633,6 +734,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, is_numerical_code=self.is_numerical_code, + code_context=code_context, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -700,6 +802,23 @@ def optimize_function(self) -> Result[BestOptimization, str]: if self.args.override_fixtures: restore_conftest(original_conftest_content) + + if self.augmented_mode: + function_result = self.get_phase1_function_result( + code_context=code_context, + original_runtime_ns=original_code_baseline.runtime if original_code_baseline else None, + original_line_profiler_results=original_code_baseline.line_profile_results.get("str_out") + if original_code_baseline and original_code_baseline.line_profile_results + else None, + best_candidate_id=best_optimization.candidate.optimization_id if best_optimization else None, + best_speedup_ratio=performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime + ) + if best_optimization + else None, + ) + self.write_phase1_output(function_result) + if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) @@ -1195,6 +1314,11 @@ def determine_best_candidate( exp_type=exp_type, ) + if self.augmented_mode: + for candidate in candidates: + phase1_result = self.collect_phase1_candidate_result(candidate=candidate, eval_ctx=eval_ctx) + self.phase1_candidate_results.append(phase1_result) + return best_optimization def call_adaptive_optimize( @@ -1601,8 +1725,31 @@ def generate_optimizations( read_only_context_code: str, run_experiment: bool = False, # noqa: FBT001, FBT002 is_numerical_code: bool | None = None, # noqa: FBT001 + code_context: CodeOptimizationContext | None = None, ) -> Result[tuple[OptimizationSet, str], str]: """Generate optimization candidates for the function. Backend handles multi-model diversity.""" + if self.augmented_mode and self.augmented_prompt_file: + self.augmented_prompts = self.load_augmented_prompts() + if not self.augmented_prompts: + return Failure("Failed to load augmented prompts from file") + if code_context is None: + return Failure("Code context required for augmented mode") + augmented_candidates = self.get_augmented_candidates(code_context) + if not augmented_candidates: + return Failure( + f"/!\\ NO AUGMENTED OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}" + ) + future_references = self.executor.submit( + get_opt_review_metrics, + self.function_to_optimize_source_code, + self.function_to_optimize.file_path, + self.function_to_optimize.qualified_name, + self.project_root, + self.test_cfg.tests_root, + ) + function_references = future_references.result() + return Success((OptimizationSet(control=augmented_candidates, experiment=None), function_references)) + n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort) future_optimization_candidates = self.executor.submit( self.aiservice_client.optimize_python_code, From 87c44e9194a6cf0c6af1717e8bb913706ff13c19 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 00:09:14 -0500 Subject: [PATCH 03/14] fix: insert global assignments after their function dependencies When LLM-generated optimizations use module-level code that depends on functions defined later in the original file (e.g., `_TABLE = func(...)`), the assignments were being inserted after imports but before the function definitions, causing NameError at import time. This fix: - Adds NameCollector visitor to extract names from assignment values - Tracks positions of function/class definitions in the module - Inserts each assignment after all its dependencies are defined - Assignments without dependencies still go after imports Fixes optimization failures for functions like `standardize_quotes` that use helper functions like `unicode_to_char` defined later in the file. --- codeflash/code_utils/code_extractor.py | 72 +++++++++++++++++++++++--- tests/test_code_replacement.py | 65 +++++++++++++++++++++++ 2 files changed, 129 insertions(+), 8 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 66dfd5eb4..d7f9961c9 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,6 +2,7 @@ import ast import time +from collections import defaultdict from dataclasses import dataclass from importlib.util import find_spec from itertools import chain @@ -25,6 +26,27 @@ from codeflash.models.models import FunctionSource +class NameCollector(cst.CSTVisitor): + """Collects all Name nodes referenced in a CST expression. + + Used to find what names an assignment depends on (e.g., function calls in the RHS). + """ + + def __init__(self) -> None: + super().__init__() + self.names: set[str] = set() + + def visit_Name(self, node: cst.Name) -> None: + self.names.add(node.value) + + +def get_names_in_expression(node: cst.BaseExpression) -> set[str]: + """Extract all names referenced in a CST expression.""" + collector = NameCollector() + node.visit(collector) + return collector.names + + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" @@ -154,20 +176,54 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) - # Find assignments to append + # Find assignments to append (with their names for dependency analysis) assignments_to_append = [ - self.new_assignments[name] + (name, self.new_assignments[name]) for name in self.new_assignment_order if name not in self.processed_assignments and name in self.new_assignments ] - if assignments_to_append: - # after last top-level imports - insert_index = find_insertion_index_after_imports(updated_node) - + if not assignments_to_append: + return updated_node.with_changes(body=new_statements) + + # Build a map of where functions/classes/assignments are defined in the module + # Maps name -> index AFTER the definition (i.e., the first valid insertion point) + definition_positions: dict[str, int] = {} + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + definition_positions[stmt.name.value] = i + 1 + elif isinstance(stmt, cst.SimpleStatementLine): + # Check for assignments + for child in stmt.body: + if isinstance(child, cst.Assign): + for target in child.targets: + if isinstance(target.target, cst.Name): + definition_positions[target.target.value] = i + 1 + + # Find the default insertion index (after imports) + default_insert_index = find_insertion_index_after_imports(updated_node) + + # For each assignment, determine its minimum insertion index + # by finding the max position of all functions/names it depends on + assignments_by_insert_index: dict[int, list[cst.Assign]] = defaultdict(list) + for _name, assignment in assignments_to_append: + # Get names referenced in the assignment's value (RHS) + dependencies = get_names_in_expression(assignment.value) + + # Find the minimum insertion index (after all dependencies are defined) + min_insert_index = default_insert_index + for dep_name in dependencies: + if dep_name in definition_positions: + min_insert_index = max(min_insert_index, definition_positions[dep_name]) + + assignments_by_insert_index[min_insert_index].append(assignment) + + # Insert assignments at their appropriate positions + # Process in reverse order of indices to avoid offset issues when inserting + for insert_index in sorted(assignments_by_insert_index.keys(), reverse=True): + assignments = assignments_by_insert_index[insert_index] assignment_lines = [ - cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for assignment in assignments_to_append + cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) for assignment in assignments ] new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 04d83f13f..337150701 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -3743,3 +3743,68 @@ def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]: project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected + + +def test_global_assignments_with_function_dependencies() -> None: + """Test that global assignments that depend on functions are inserted after those functions. + + This tests the fix for a bug where LLM-generated optimizations that use module-level + code like `_TABLE = unicode_to_char(...)` would fail because the assignment was inserted + before the function definition. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original file: standardize_quotes first, unicode_to_char second + original_code = '''def standardize_quotes(text: str) -> str: + """Standardize quotes in text.""" + return text + +def unicode_to_char(unicode_val: str) -> str: + """Convert unicode value to char.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code: defines unicode_to_char first, then module-level code that uses it + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Convert unicode value to char.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_CODES = ("U+0022", "U+201C") +_TRANSLATION_TABLE = {ord(unicode_to_char(c)): ord('"') for c in _CODES} + +def standardize_quotes(text: str) -> str: + """Standardize quotes in text.""" + return text.translate(_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The assignment that depends on unicode_to_char should be inserted AFTER unicode_to_char + # not after imports (which would cause a NameError) + + # Parse the result and verify the order + import libcst as cst + + module = cst.parse_module(result) + + # Find positions of key elements + unicode_to_char_pos = None + translation_table_pos = None + + for i, stmt in enumerate(module.body): + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "unicode_to_char": + unicode_to_char_pos = i + elif isinstance(stmt, cst.SimpleStatementLine): + for child in stmt.body: + if isinstance(child, cst.Assign): + for target in child.targets: + if isinstance(target.target, cst.Name) and target.target.value == "_TRANSLATION_TABLE": + translation_table_pos = i + + # Verify that _TRANSLATION_TABLE comes AFTER unicode_to_char + assert unicode_to_char_pos is not None, "unicode_to_char function not found in result" + assert translation_table_pos is not None, "_TRANSLATION_TABLE assignment not found in result" + assert translation_table_pos > unicode_to_char_pos, ( + f"_TRANSLATION_TABLE (pos {translation_table_pos}) should be after " + f"unicode_to_char (pos {unicode_to_char_pos}) because it depends on it" + ) From c5f5710fadb6b30476908978020c828e7747eaba Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 00:10:50 -0500 Subject: [PATCH 04/14] feat: add create-pr CLI command for creating PRs from augmented optimization results Adds a new `codeflash create-pr` subcommand that creates PRs from previously applied optimizations stored in a JSON results file. This enables a two-phase workflow where optimizations are applied locally first, then PRs can be created separately. --- codeflash/cli_cmds/cli.py | 15 ++ codeflash/cli_cmds/cmd_create_pr.py | 145 +++++++++++++++++++ codeflash/models/models.py | 7 + codeflash/optimization/function_optimizer.py | 21 +++ 4 files changed, 188 insertions(+) create mode 100644 codeflash/cli_cmds/cmd_create_pr.py diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index db72ee9be..2de7a804d 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -6,6 +6,7 @@ from codeflash.cli_cmds import logging_config from codeflash.cli_cmds.cli_common import apologize_and_exit +from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.extension import install_vscode_extension @@ -57,6 +58,20 @@ def parse_args() -> Namespace: help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.", ) + # create-pr subcommand for creating PRs from augmented optimization results + create_pr_parser = subparsers.add_parser("create-pr", help="Create a PR from previously applied optimizations") + create_pr_parser.set_defaults(func=cmd_create_pr) + create_pr_parser.add_argument( + "--results-file", + type=str, + default="codeflash_phase1_results.json", + help="Path to augmented output JSON file (default: codeflash_phase1_results.json)", + ) + create_pr_parser.add_argument( + "--function", type=str, help="Function name (required if multiple functions in results)" + ) + create_pr_parser.add_argument("--git-remote", type=str, help="Git remote to use for PR creation (default: origin)") + parser.add_argument("--file", help="Try to optimize only this file") parser.add_argument("--function", help="Try to optimize only this function within the given file path") parser.add_argument( diff --git a/codeflash/cli_cmds/cmd_create_pr.py b/codeflash/cli_cmds/cmd_create_pr.py new file mode 100644 index 000000000..d5ca3946f --- /dev/null +++ b/codeflash/cli_cmds/cmd_create_pr.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.code_utils import exit_with_message +from codeflash.code_utils.git_utils import git_root_dir +from codeflash.models.models import Phase1Output, TestResults +from codeflash.result.create_pr import check_create_pr +from codeflash.result.explanation import Explanation + +if TYPE_CHECKING: + from argparse import Namespace + +# Pattern to extract file path from markdown code block: ```python:path/to/file.py +MARKDOWN_FILE_PATH_PATTERN = re.compile(r"```python:([^\n]+)") + + +def extract_file_path_from_markdown(markdown_code: str) -> str | None: + """Extract the file path from markdown code block format. + + Format: ```python:path/to/file.py + """ + match = MARKDOWN_FILE_PATH_PATTERN.search(markdown_code) + if match: + return match.group(1).strip() + return None + + +def extract_code_from_markdown(markdown_code: str) -> str: + r"""Extract the code content from markdown code block. + + Removes the ```python:path\n ... ``` wrapper. + """ + # Remove opening markdown fence with optional path + code = re.sub(r"^```python(?::[^\n]*)?\n", "", markdown_code) + # Remove closing fence + return re.sub(r"\n```$", "", code) + + +def create_pr(args: Namespace) -> None: + """Create a PR from previously applied optimizations.""" + results_file = Path(args.results_file) + + if not results_file.exists(): + exit_with_message(f"Results file not found: {results_file}", error_on_exit=True) + + # Load and parse results + with results_file.open(encoding="utf-8") as f: + data = json.load(f) + + try: + output = Phase1Output.model_validate(data) + except Exception as e: + exit_with_message(f"Failed to parse results file: {e}", error_on_exit=True) + + # Find the function result + if len(output.functions) == 0: + exit_with_message("No functions in results file", error_on_exit=True) + + if len(output.functions) > 1 and not args.function: + func_names = [f.function_name for f in output.functions] + exit_with_message( + f"Multiple functions in results. Specify one with --function: {func_names}", error_on_exit=True + ) + + func_result = output.functions[0] + if args.function: + func_result = next((f for f in output.functions if f.function_name == args.function), None) + if not func_result: + exit_with_message(f"Function {args.function} not found in results", error_on_exit=True) + assert func_result is not None # for type checker - exit_with_message doesn't return + + if not func_result.best_candidate_id: + exit_with_message("No successful optimization found in results", error_on_exit=True) + + # Get file path - prefer explicit field, fall back to extracting from markdown + file_path_str = func_result.file_path + if not file_path_str: + file_path_str = extract_file_path_from_markdown(func_result.original_source_code) + + if not file_path_str: + exit_with_message( + "Could not determine file path from results. Results file may be from an older version of codeflash.", + error_on_exit=True, + ) + assert file_path_str is not None # for type checker - exit_with_message doesn't return + + file_path = Path(file_path_str) + if not file_path.exists(): + exit_with_message(f"Source file not found: {file_path}", error_on_exit=True) + + # Read current (optimized) file content + current_content = file_path.read_text(encoding="utf-8") + + # Extract original code (strip markdown) + original_code = extract_code_from_markdown(func_result.original_source_code) + + # Get the best candidate's explanation + best_explanation = func_result.best_candidate_explanation + if not best_explanation: + # Fall back to the candidate's explanation if the final explanation wasn't captured + best_candidate = next( + (c for c in func_result.candidates if c.optimization_id == func_result.best_candidate_id), None + ) + best_explanation = best_candidate.explanation if best_candidate else "Optimization applied" + + # Build Explanation object for PR creation + explanation = Explanation( + raw_explanation_message=best_explanation, + winning_behavior_test_results=TestResults(), + winning_benchmarking_test_results=TestResults(), + original_runtime_ns=func_result.original_runtime_ns or 0, + best_runtime_ns=func_result.best_runtime_ns or func_result.original_runtime_ns or 0, + function_name=func_result.function_name, + file_path=file_path, + ) + + logger.info(f"Creating PR for optimized function: {func_result.function_name}") + logger.info(f"File: {file_path}") + if func_result.best_speedup_ratio: + logger.info(f"Speedup: {func_result.best_speedup_ratio * 100:.1f}%") + + # Call existing PR creation + check_create_pr( + original_code={file_path: original_code}, + new_code={file_path: current_content}, + explanation=explanation, + existing_tests_source=func_result.existing_tests_source or "", + generated_original_test_source="", + function_trace_id=func_result.trace_id, + coverage_message="", + replay_tests=func_result.replay_tests_source or "", + concolic_tests=func_result.concolic_tests_source or "", + optimization_review="", + root_dir=git_root_dir(), + git_remote=getattr(args, "git_remote", None), + ) + + # Cleanup results file after successful PR creation + results_file.unlink() + logger.info(f"Cleaned up results file: {results_file}") diff --git a/codeflash/models/models.py b/codeflash/models/models.py index c812919cf..d886aa633 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -938,6 +938,13 @@ class Phase1FunctionResult(BaseModel): candidates: list[Phase1CandidateResult] best_candidate_id: Optional[str] = None best_speedup_ratio: Optional[float] = None + # PR creation data - captured after best candidate is selected + file_path: Optional[str] = None + existing_tests_source: Optional[str] = None + replay_tests_source: Optional[str] = None + concolic_tests_source: Optional[str] = None + best_candidate_explanation: Optional[str] = None + best_runtime_ns: Optional[int] = None class Phase1Output(BaseModel): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e523a65c0..54d67220f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -480,6 +480,12 @@ def __init__( self.augmented_output = getattr(args, "augmented_output", "codeflash_phase1_results.json") if args else None self.augmented_prompts: AugmentedPrompts | None = None self.phase1_candidate_results: list[Phase1CandidateResult] = [] + # PR data captured during process_review for Phase1 output + self.phase1_existing_tests: str | None = None + self.phase1_replay_tests: str | None = None + self.phase1_concolic_tests: str | None = None + self.phase1_best_explanation: str | None = None + self.phase1_best_runtime: int | None = None def load_augmented_prompts(self) -> AugmentedPrompts | None: if not self.augmented_prompt_file: @@ -536,6 +542,13 @@ def get_phase1_function_result( candidates=self.phase1_candidate_results, best_candidate_id=best_candidate_id, best_speedup_ratio=best_speedup_ratio, + # PR creation data + file_path=self.function_to_optimize.file_path.as_posix(), + existing_tests_source=self.phase1_existing_tests, + replay_tests_source=self.phase1_replay_tests, + concolic_tests_source=self.phase1_concolic_tests, + best_candidate_explanation=self.phase1_best_explanation, + best_runtime_ns=self.phase1_best_runtime, ) def write_phase1_output(self, function_result: Phase1FunctionResult) -> None: @@ -2079,6 +2092,14 @@ def process_review( best_optimization.explanation_v2 = new_explanation.explanation_message() + # Capture PR data for Phase1 output in augmented mode + if self.augmented_mode: + self.phase1_existing_tests = existing_tests + self.phase1_replay_tests = replay_tests + self.phase1_concolic_tests = concolic_tests + self.phase1_best_explanation = new_explanation.explanation_message() + self.phase1_best_runtime = best_optimization.runtime + data = { "original_code": original_code_combined, "new_code": new_code_combined, From 10389b52731ec52fc64dbf6b567669c53eb5ff98 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 00:40:39 -0500 Subject: [PATCH 05/14] fix: transfer module-level compound statements (for/while/with/try) in add_global_assignments Previously, GlobalStatementCollector only collected SimpleStatementLine nodes, causing module-level for-loops (and other compound statements) to be dropped. This caused NameError for loop variables like 'uval', 'unicode_val', 'ch' when LLM-generated optimizations used for-loops to build translation tables. - Add compound statement collection to GlobalStatementCollector - Update ImportInserter to handle both simple and compound statements - Add tests for all reported NameError cases --- codeflash/code_utils/code_extractor.py | 179 ++++++++++--- tests/test_code_replacement.py | 340 +++++++++++++++++++++++++ 2 files changed, 487 insertions(+), 32 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index d7f9961c9..c3bf18e45 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -48,7 +48,13 @@ def get_names_in_expression(node: cst.BaseExpression) -> set[str]: class GlobalAssignmentCollector(cst.CSTVisitor): - """Collects all global assignment statements.""" + """Collects all global assignment statements. + + Only collects simple assignments at module level, NOT inside: + - Functions/classes (scope_depth) + - If/else blocks (if_else_depth) + - For/while loops, with statements, try blocks (compound_depth) + """ def __init__(self) -> None: super().__init__() @@ -57,6 +63,9 @@ def __init__(self) -> None: # Track scope depth to identify global assignments self.scope_depth = 0 self.if_else_depth = 0 + # Track other compound statements (for, while, with, try) where assignments + # inside them depend on loop variables or context managers + self.compound_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 @@ -83,9 +92,37 @@ def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002 # Else blocks are already counted as part of the if statement return True + def visit_For(self, node: cst.For) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_For(self, original_node: cst.For) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_While(self, node: cst.While) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_While(self, original_node: cst.While) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_With(self, node: cst.With) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_With(self, original_node: cst.With) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Try(self, node: cst.Try) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_Assign(self, node: cst.Assign) -> Optional[bool]: - # Only process global assignments (not inside functions, classes, etc.) - if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level + # Only process global assignments (not inside functions, classes, loops, etc.) + if self.scope_depth == 0 and self.if_else_depth == 0 and self.compound_depth == 0: for target in node.targets: if isinstance(target.target, cst.Name): name = target.target.value @@ -123,7 +160,13 @@ def find_insertion_index_after_imports(node: cst.Module) -> int: class GlobalAssignmentTransformer(cst.CSTTransformer): - """Transforms global assignments in the original file with those from the new file.""" + """Transforms global assignments in the original file with those from the new file. + + Only transforms simple assignments at module level, NOT inside: + - Functions/classes (scope_depth) + - If/else blocks (if_else_depth) + - For/while loops, with statements, try blocks (compound_depth) + """ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None: super().__init__() @@ -132,6 +175,7 @@ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: self.processed_assignments: set[str] = set() self.scope_depth = 0 self.if_else_depth = 0 + self.compound_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth += 1 @@ -158,8 +202,36 @@ def visit_Else(self, node: cst.Else) -> None: # Else blocks are already counted as part of the if statement pass + def visit_For(self, node: cst.For) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_For(self, original_node: cst.For, updated_node: cst.For) -> cst.For: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_While(self, node: cst.While) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_While(self, original_node: cst.While, updated_node: cst.While) -> cst.While: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_With(self, node: cst.With) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + + def visit_Try(self, node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode: - if self.scope_depth > 0 or self.if_else_depth > 0: + if self.scope_depth > 0 or self.if_else_depth > 0 or self.compound_depth > 0: return updated_node # Check if this is a global assignment we need to replace @@ -243,12 +315,17 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c class GlobalStatementCollector(cst.CSTVisitor): - """Visitor that collects all global statements (excluding imports and functions/classes).""" + """Visitor that collects all global statements (excluding imports and functions/classes). + + Collects both simple statements and compound statements (for, while, with, try) + at module level. + """ def __init__(self) -> None: super().__init__() - self.global_statements = [] + self.global_statements: list[cst.BaseStatement] = [] self.in_function_or_class = False + self.compound_depth = 0 def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 # Don't visit inside classes @@ -266,8 +343,44 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002 def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.in_function_or_class = False + def visit_For(self, node: cst.For) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_For(self, original_node: cst.For) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_While(self, node: cst.While) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_While(self, original_node: cst.While) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_With(self, node: cst.With) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_With(self, original_node: cst.With) -> None: # noqa: ARG002 + self.compound_depth -= 1 + + def visit_Try(self, node: cst.Try) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: - if not self.in_function_or_class: + if not self.in_function_or_class and self.compound_depth == 0: for statement in node.body: # Skip imports if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): @@ -366,40 +479,42 @@ def visit_Try(self, node: cst.Try) -> None: class ImportInserter(cst.CSTTransformer): - """Transformer that inserts global statements after the last import.""" + """Transformer that inserts global statements after the last import. - def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None: + Handles both simple statements and compound statements (for, while, with, try). + """ + + def __init__(self, global_statements: list[cst.BaseStatement], last_import_line: int) -> None: super().__init__() self.global_statements = global_statements self.last_import_line = last_import_line - self.current_line = 0 - self.inserted = False - def leave_SimpleStatementLine( - self, - original_node: cst.SimpleStatementLine, # noqa: ARG002 - updated_node: cst.SimpleStatementLine, - ) -> cst.Module: - self.current_line += 1 + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.global_statements: + return updated_node - # If we're right after the last import and haven't inserted yet - if self.current_line == self.last_import_line and not self.inserted: - self.inserted = True - return cst.Module(body=[updated_node, *self.global_statements]) + # If no imports, insert at the beginning + if self.last_import_line == 0: + updated_body = list(self.global_statements) + list(updated_node.body) + return updated_node.with_changes(body=updated_body) - return cst.Module(body=[updated_node]) + # Find the insertion point: after the last import line + # last_import_line is 1-indexed, so compare with i + 1 + insertion_index = 0 + for i in range(len(updated_node.body)): + if i + 1 == self.last_import_line: + insertion_index = i + 1 + break - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 - # If there were no imports, add at the beginning of the module - if self.last_import_line == 0 and not self.inserted: - updated_body = list(updated_node.body) - for stmt in reversed(self.global_statements): - updated_body.insert(0, stmt) - return updated_node.with_changes(body=updated_body) - return updated_node + # Insert global statements at the insertion point + updated_body = list(updated_node.body) + for j, stmt in enumerate(self.global_statements): + updated_body.insert(insertion_index + j, stmt) + + return updated_node.with_changes(body=updated_body) -def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: +def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.BaseStatement]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) collector = GlobalStatementCollector() diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 337150701..fb3e913af 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -3808,3 +3808,343 @@ def standardize_quotes(text: str) -> str: f"_TRANSLATION_TABLE (pos {translation_table_pos}) should be after " f"unicode_to_char (pos {unicode_to_char_pos}) because it depends on it" ) + + +def test_global_assignments_inside_for_loops_not_extracted() -> None: + """Test that assignments inside for-loops are NOT extracted as standalone globals. + + This tests the fix for a bug where LLM-generated optimizations that build + translation tables using for-loops would have the loop body assignments + incorrectly extracted, causing NameError for loop variables. + """ + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + # Code with assignments inside a for-loop (common optimization pattern) + # Note: Using regular assignment (not annotated) since GlobalAssignmentCollector only handles Assign + code_with_for_loop = ''' +double_quotes = {"a": "U+0022", "b": "U+201C"} + +_QUOTE_TRANSLATION = {} +for unicode_val in double_quotes.values(): + ch = unicode_to_char(unicode_val) + _QUOTE_TRANSLATION[ord(ch)] = _double_quote_standard + +def standardize_quotes(text: str) -> str: + return text.translate(_QUOTE_TRANSLATION) +''' + + module = cst.parse_module(code_with_for_loop) + collector = GlobalAssignmentCollector() + module.visit(collector) + + # Only the top-level assignments should be collected, NOT the one inside the for-loop + # _QUOTE_TRANSLATION = {} is at module level (should be collected) + # _QUOTE_TRANSLATION[ord(ch)] = ... is inside for-loop (should NOT be collected) + # double_quotes = {...} is at module level (should be collected) + assert "double_quotes" in collector.assignments, "double_quotes should be collected" + assert "_QUOTE_TRANSLATION" in collector.assignments, "_QUOTE_TRANSLATION init should be collected" + + # The assignment inside the for-loop uses subscript, not Name, so it wouldn't be + # collected anyway. But let's verify the collector doesn't crash and works correctly. + # More importantly, verify that simple name assignments inside loops are NOT collected. + + code_with_simple_assignment_in_loop = ''' +result = {} +for item in items: + key = process(item) + result[key] = item +''' + module2 = cst.parse_module(code_with_simple_assignment_in_loop) + collector2 = GlobalAssignmentCollector() + module2.visit(collector2) + + assert "result" in collector2.assignments, "result should be collected (top-level)" + assert "key" not in collector2.assignments, "key should NOT be collected (inside for-loop)" + + +def test_global_assignments_inside_while_loops_not_extracted() -> None: + """Test that assignments inside while-loops are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_while_loop = ''' +counter = 0 +while counter < 10: + value = compute(counter) + counter += 1 +''' + module = cst.parse_module(code_with_while_loop) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "counter" in collector.assignments, "counter should be collected (top-level)" + assert "value" not in collector.assignments, "value should NOT be collected (inside while-loop)" + + +def test_global_assignments_inside_with_blocks_not_extracted() -> None: + """Test that assignments inside with-blocks are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_with_block = ''' +config = {} +with open("file.txt") as f: + content = f.read() + data = parse(content) +''' + module = cst.parse_module(code_with_with_block) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "config" in collector.assignments, "config should be collected (top-level)" + assert "content" not in collector.assignments, "content should NOT be collected (inside with-block)" + assert "data" not in collector.assignments, "data should NOT be collected (inside with-block)" + + +def test_global_assignments_inside_try_blocks_not_extracted() -> None: + """Test that assignments inside try-blocks are NOT extracted.""" + from codeflash.code_utils.code_extractor import GlobalAssignmentCollector + + import libcst as cst + + code_with_try_block = ''' +default = "fallback" +try: + result = risky_operation() + processed = transform(result) +except Exception: + pass +''' + module = cst.parse_module(code_with_try_block) + collector = GlobalAssignmentCollector() + module.visit(collector) + + assert "default" in collector.assignments, "default should be collected (top-level)" + assert "result" not in collector.assignments, "result should NOT be collected (inside try-block)" + assert "processed" not in collector.assignments, "processed should NOT be collected (inside try-block)" + + +def test_global_assignment_transformer_ignores_loop_assignments() -> None: + """Test that GlobalAssignmentTransformer doesn't replace assignments inside loops. + + The transformer should: + 1. NOT replace assignments inside for/while/with/try blocks + 2. Still add new top-level assignments that weren't in the original + """ + from codeflash.code_utils.code_extractor import GlobalAssignmentTransformer + + import libcst as cst + + # Original code has 'key' inside a for-loop and 'result' at top level + original_code = ''' +result = {} +for item in items: + key = process(item) + result[key] = item +''' + # New assignments - 'result' should replace top-level, 'key' should NOT replace loop var + new_assignments = { + "result": cst.parse_statement("result = {'new': 'dict'}").body[0], + "key": cst.parse_statement("key = 'new_value'").body[0], + } + + module = cst.parse_module(original_code) + transformer = GlobalAssignmentTransformer(new_assignments, ["result", "key"]) + result_module = module.visit(transformer) + + result_code = result_module.code + + # The 'key' inside the for-loop should NOT be replaced + assert "key = process(item)" in result_code, "Assignment inside for-loop should not be transformed" + + # The top-level 'result' SHOULD be replaced + assert "result = {'new': 'dict'}" in result_code, "Top-level assignment should be replaced" + assert "result = {}" not in result_code, "Original top-level assignment should be gone" + + +def test_add_global_assignments_with_loop_variables() -> None: + """Test that add_global_assignments doesn't extract assignments that reference loop variables. + + This is the specific bug case from optimization runs where code like: + for uval in double_quotes.values(): + ch = unicode_to_char(uval) + _translation_map[ord(ch)] = _double_quote_ord + + Would have 'ch = unicode_to_char(uval)' extracted and inserted at module level, + causing 'NameError: name 'uval' is not defined'. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original simple function + original_code = '''def standardize_quotes(text: str) -> str: + return text + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with for-loop that builds translation table at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {"a": "U+0022", "b": "U+201C"} +_translation_map = {} + +for uval in double_quotes.values(): + ch = unicode_to_char(uval) + _translation_map[ord(ch)] = ord('"') + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + # If 'ch = unicode_to_char(uval)' was incorrectly extracted, it would cause + # NameError because 'uval' wouldn't be defined outside the for-loop + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (loop var extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop structure is preserved (ch assignment inside loop) + assert "for uval in" in result or "for unicode_val in" in result or "ch" not in result, ( + "If ch is in result, the for-loop should also be present" + ) + + +def test_add_global_assignments_with_unicode_val_loop_variable() -> None: + """Test that add_global_assignments correctly transfers for-loops using 'unicode_val' as loop variable. + + This is the specific bug case: NameError: name 'unicode_val' is not defined + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def standardize_quotes(text: str) -> str: + return text +''' + + # Optimized code with for-loop using unicode_val as the loop variable + optimized_code = '''single_quotes = {"U+0027": "'", "U+2018": "'", "U+2019": "'"} +_translation_map = {} + +for unicode_val in single_quotes: + ch = chr(int(unicode_val.replace("U+", ""), 16)) + _translation_map[ord(ch)] = ord("'") + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (unicode_val extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for unicode_val in" in result, "For-loop with unicode_val should be transferred" + + +def test_add_global_assignments_with_ch_loop_variable() -> None: + """Test that add_global_assignments correctly transfers for-loops using 'ch' as loop variable. + + This is the specific bug case: NameError: name 'ch' is not defined + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def normalize_text(text: str) -> str: + return text +''' + + # Optimized code with for-loop using ch as the loop variable + optimized_code = '''replacements = {"a": "A", "b": "B"} +_char_map = {} + +for ch in replacements: + _char_map[ord(ch)] = ord(replacements[ch]) + +def normalize_text(text: str) -> str: + return text.translate(_char_map) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError (ch extracted incorrectly): {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for ch in" in result, "For-loop with ch should be transferred" + + +def test_add_global_assignments_with_helper_function_call() -> None: + """Test that add_global_assignments transfers assignments that call helper functions. + + Note: add_global_assignments only handles assignments, not function definitions. + Function definitions are transferred via replace_functions_in_file separately. + + This test verifies that assignments calling helper functions are correctly transferred. + The actual NameError: name '_build_quote_translation_table' is not defined would occur + if the function wasn't transferred in the full replacement flow. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code that already has the helper function defined + original_code = '''def _build_quote_translation_table(): + table = {} + for i in range(128): + table[i] = i + return table + +def standardize_quotes(text: str) -> str: + return text +''' + + # Optimized code with an assignment that calls the helper function + optimized_code = '''def _build_quote_translation_table(): + table = {} + for i in range(128): + table[i] = i + return table + +_QUOTE_TABLE = _build_quote_translation_table() + +def standardize_quotes(text: str) -> str: + return text.translate(_QUOTE_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when compiled + try: + compile(result, "", "exec") + except NameError as e: + raise AssertionError(f"Generated code has NameError: {e}") from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the assignment is present and placed AFTER the function definition + assert "_QUOTE_TABLE = _build_quote_translation_table()" in result, "Assignment should be transferred" + # Verify the function is still present (it was already in original) + assert "def _build_quote_translation_table" in result, "Helper function should be preserved" + + # Verify correct ordering: function must come before the assignment that uses it + func_pos = result.index("def _build_quote_translation_table") + assign_pos = result.index("_QUOTE_TABLE = _build_quote_translation_table()") + assert func_pos < assign_pos, "Function definition must come before assignment that calls it" From b234b831a06cacbc1f62d2907eafe3649d56d1ae Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 00:54:06 -0500 Subject: [PATCH 06/14] fix: insert compound statements at end of module to avoid NameError For-loops and other compound statements at module level may call functions defined later in the file. Previously, they were inserted after imports, causing NameError when the called function was defined after the for-loop. Now: - Simple statements (assignments) are inserted after imports - Compound statements (for/while/with/try) are inserted at the END of the module This fixes NameError cases like 'unicode_to_char is not defined' when LLM-generated optimizations use module-level for-loops to build translation tables that call helper functions. --- codeflash/code_utils/code_extractor.py | 51 ++++--- tests/test_code_replacement.py | 180 +++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 17 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index c3bf18e45..02e875c8c 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -479,9 +479,11 @@ def visit_Try(self, node: cst.Try) -> None: class ImportInserter(cst.CSTTransformer): - """Transformer that inserts global statements after the last import. + """Transformer that inserts global statements into a module. - Handles both simple statements and compound statements (for, while, with, try). + - Simple statements (assignments, etc.) are inserted after imports + - Compound statements (for, while, with, try) are inserted at the END of the module + because they may call functions that are defined later in the file """ def __init__(self, global_statements: list[cst.BaseStatement], last_import_line: int) -> None: @@ -493,23 +495,38 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if not self.global_statements: return updated_node - # If no imports, insert at the beginning - if self.last_import_line == 0: - updated_body = list(self.global_statements) + list(updated_node.body) - return updated_node.with_changes(body=updated_body) - - # Find the insertion point: after the last import line - # last_import_line is 1-indexed, so compare with i + 1 - insertion_index = 0 - for i in range(len(updated_node.body)): - if i + 1 == self.last_import_line: - insertion_index = i + 1 - break + # Separate simple statements from compound statements + simple_statements = [] + compound_statements = [] + for stmt in self.global_statements: + if isinstance(stmt, cst.SimpleStatementLine): + simple_statements.append(stmt) + else: + # For, While, With, Try, If, etc. are compound statements + compound_statements.append(stmt) - # Insert global statements at the insertion point updated_body = list(updated_node.body) - for j, stmt in enumerate(self.global_statements): - updated_body.insert(insertion_index + j, stmt) + + # Insert simple statements after imports (or at beginning if no imports) + if simple_statements: + if self.last_import_line == 0: + # No imports, insert at the beginning + for j, stmt in enumerate(simple_statements): + updated_body.insert(j, stmt) + else: + # Find insertion point after last import + insertion_index = 0 + for i in range(len(updated_body)): + if i + 1 == self.last_import_line: + insertion_index = i + 1 + break + for j, stmt in enumerate(simple_statements): + updated_body.insert(insertion_index + j, stmt) + + # Append compound statements at the END of the module + # This ensures they run after all function definitions + if compound_statements: + updated_body.extend(compound_statements) return updated_node.with_changes(body=updated_body) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index fb3e913af..46a4fb084 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -4148,3 +4148,183 @@ def standardize_quotes(text: str) -> str: func_pos = result.index("def _build_quote_translation_table") assign_pos = result.index("_QUOTE_TABLE = _build_quote_translation_table()") assert func_pos < assign_pos, "Function definition must come before assignment that calls it" + + +def test_add_global_assignments_forloop_calls_function_defined_later() -> None: + """Test that for-loops calling functions are placed AFTER those function definitions. + + This is the specific bug case: NameError: name 'unicode_to_char' is not defined + + When the original file has a function defined later (e.g., after standardize_quotes), + and the optimized code adds a for-loop that calls that function, the for-loop must + be placed AFTER the function definition, not at the top of the file. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code where unicode_to_char is defined AFTER the main function + # This mirrors the real-world case where the helper is at the bottom + original_code = '''def standardize_quotes(text: str) -> str: + return text + + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with a for-loop that calls unicode_to_char at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_UNICODE_VALUES = ["U+0022", "U+201C", "U+201D"] +_TRANSLATION_TABLE = {} + +for code in _DOUBLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(code) + _TRANSLATION_TABLE[ord(ch)] = ord('"') + +def standardize_quotes(text: str) -> str: + return text.translate(_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + # Also try to actually execute it to catch runtime NameErrors + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError (for-loop placed before function): {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for code in _DOUBLE_QUOTE_UNICODE_VALUES" in result, "For-loop should be transferred" + + # Verify correct ordering: function must come before the for-loop that calls it + func_pos = result.index("def unicode_to_char") + forloop_pos = result.index("for code in _DOUBLE_QUOTE_UNICODE_VALUES") + assert func_pos < forloop_pos, ( + f"Function definition must come before for-loop that calls it.\n" + f"Function at position {func_pos}, for-loop at position {forloop_pos}\n" + f"Generated code:\n{result}" + ) + + +def test_add_global_assignments_forloop_uses_computed_variable() -> None: + """Test that for-loops are placed after variables they depend on. + + This is the specific bug case: NameError: name '_double_chars' is not defined + + When optimized code computes a variable and then uses it in a for-loop, + the assignment must come before the for-loop. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''def process_text(text: str) -> str: + return text +''' + + # Optimized code where _double_chars is computed and then used in a for-loop + optimized_code = '''_UNICODE_VALUES = ["U+0022", "U+201C"] +_double_chars = tuple(chr(int(u.replace("U+", ""), 16)) for u in _UNICODE_VALUES) + +_TRANSLATION = {} +for ch in _double_chars: + _TRANSLATION[ord(ch)] = ord('"') + +def process_text(text: str) -> str: + return text.translate(_TRANSLATION) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError (variable not defined before for-loop): {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify the for-loop is present + assert "for ch in _double_chars" in result, "For-loop should be transferred" + + # Verify correct ordering: _double_chars assignment must come before the for-loop + assign_pos = result.index("_double_chars = ") + forloop_pos = result.index("for ch in _double_chars") + assert assign_pos < forloop_pos, ( + f"Variable assignment must come before for-loop that uses it.\n" + f"Assignment at position {assign_pos}, for-loop at position {forloop_pos}\n" + f"Generated code:\n{result}" + ) + + +def test_add_global_assignments_multiple_forloops_with_dependencies() -> None: + """Test that multiple for-loops with function dependencies are ordered correctly. + + This tests the real-world case from standardize_quotes optimization where: + 1. unicode_to_char function is used + 2. Multiple for-loops build translation tables + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code with function defined after main function + original_code = '''def standardize_quotes(text: str) -> str: + return text + + +def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimized code with multiple for-loops that depend on unicode_to_char + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {"U+0022": '"', "U+201C": '"'} +single_quotes = {"U+0027": "'", "U+2018": "'"} + +_translation_table = {} + +for unicode_val in double_quotes: + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = ord('"') + +for unicode_val in single_quotes: + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = ord("'") + +def standardize_quotes(text: str) -> str: + return text.translate(_translation_table) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError: {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify both for-loops are present + assert "for unicode_val in double_quotes" in result, "First for-loop should be transferred" + assert "for unicode_val in single_quotes" in result, "Second for-loop should be transferred" + + # Verify correct ordering: function must come before all for-loops + func_pos = result.index("def unicode_to_char") + first_forloop_pos = result.index("for unicode_val in double_quotes") + second_forloop_pos = result.index("for unicode_val in single_quotes") + + assert func_pos < first_forloop_pos, "Function must come before first for-loop" + assert func_pos < second_forloop_pos, "Function must come before second for-loop" From 206fa1cc836f4f89fbceb521a78fabf15e00293a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 01:11:16 -0500 Subject: [PATCH 07/14] fix: transfer new function definitions in add_global_assignments When LLM-generated optimizations introduce new helper functions that are called at module level, these functions must be transferred to the original code. Added FunctionDefCollector, FunctionNameCollector, and FunctionDefInserter to detect and transfer new function definitions before assignments that depend on them. --- codeflash/code_utils/code_extractor.py | 107 +++++ tests/test_code_replacement.py | 635 +++++++++++++++++++++++++ 2 files changed, 742 insertions(+) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 02e875c8c..3205dd2f7 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -314,6 +314,56 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node.with_changes(body=new_statements) +class FunctionDefCollector(cst.CSTVisitor): + """Collects all top-level function definitions from a module. + + Used to find new helper functions in optimized code that need to be transferred. + """ + + def __init__(self) -> None: + super().__init__() + self.function_defs: dict[str, cst.FunctionDef] = {} + self.function_order: list[str] = [] + self.in_class = False + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 + self.in_class = True + return False # Don't visit inside classes + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.in_class = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if not self.in_class: + name = node.name.value + self.function_defs[name] = node + self.function_order.append(name) + return False # Don't visit inside functions + + +class FunctionNameCollector(cst.CSTVisitor): + """Collects all top-level function and class names from a module.""" + + def __init__(self) -> None: + super().__init__() + self.names: set[str] = set() + self.in_class = False + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + if not self.in_class: + self.names.add(node.name.value) + self.in_class = True + return False + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.in_class = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if not self.in_class: + self.names.add(node.name.value) + return False + + class GlobalStatementCollector(cst.CSTVisitor): """Visitor that collects all global statements (excluding imports and functions/classes). @@ -478,6 +528,41 @@ def visit_Try(self, node: cst.Try) -> None: self._collect_imports_from_block(node.body) +class FunctionDefInserter(cst.CSTTransformer): + """Transformer that inserts new function definitions into a module. + + New function definitions are inserted after imports but before any code that depends on them. + """ + + def __init__(self, function_defs: list[cst.FunctionDef], last_import_line: int) -> None: + super().__init__() + self.function_defs = function_defs + self.last_import_line = last_import_line + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.function_defs: + return updated_node + + updated_body = list(updated_node.body) + + # Insert function definitions after imports (or at beginning if no imports) + if self.last_import_line == 0: + # No imports, insert at the beginning + for j, func_def in enumerate(self.function_defs): + updated_body.insert(j, func_def) + else: + # Find insertion point after last import + insertion_index = 0 + for i in range(len(updated_body)): + if i + 1 == self.last_import_line: + insertion_index = i + 1 + break + for j, func_def in enumerate(self.function_defs): + updated_body.insert(insertion_index + j, func_def) + + return updated_node.with_changes(body=updated_body) + + class ImportInserter(cst.CSTTransformer): """Transformer that inserts global statements into a module. @@ -574,6 +659,20 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: src_module, new_added_global_statements = extract_global_statements(src_module_code) dst_module, existing_global_statements = extract_global_statements(dst_module_code) + # Collect function definitions from source and destination + src_func_collector = FunctionDefCollector() + src_module.visit(src_func_collector) + + dst_name_collector = FunctionNameCollector() + dst_module.visit(dst_name_collector) + + # Find new function definitions that don't exist in destination + new_function_defs: list[cst.FunctionDef] = [ + src_func_collector.function_defs[func_name] + for func_name in src_func_collector.function_order + if func_name not in dst_name_collector.names + ] + unique_global_statements = [] for stmt in new_added_global_statements: if any( @@ -597,6 +696,14 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: # No new statements to insert, reuse already-parsed dst_module original_module = dst_module + # Insert new function definitions if any + if new_function_defs: + last_import_line = find_last_import_line(mod_dst_code) + transformer = FunctionDefInserter(new_function_defs, last_import_line) + modified_module = original_module.visit(transformer) + mod_dst_code = modified_module.code + original_module = cst.parse_module(mod_dst_code) + # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file new_collector = GlobalAssignmentCollector() diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 46a4fb084..b5c52acb1 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -4328,3 +4328,638 @@ def standardize_quotes(text: str) -> str: assert func_pos < first_forloop_pos, "Function must come before first for-loop" assert func_pos < second_forloop_pos, "Function must come before second for-loop" + + +# ============================================================================= +# Real-world standardize_quotes optimization tests +# These tests verify the fixes work for the actual optimization scenarios +# ============================================================================= + + +def test_standardize_quotes_optimization_candidate_6_pattern() -> None: + """Test optimization pattern from candidate 6 that caused NameError: name '_double_chars' is not defined. + + This pattern uses tuple comprehensions at module level that depend on unicode_to_char, + then uses those tuples in for-loops to build translation tables. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original standardize_quotes code structure + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + double_quotes = { + '"': "U+0022", + '"': "U+201C", + '"': "U+201D", + } + single_quotes = { + "'": "U+0027", + "'": "U+2018", + "'": "U+2019", + } + + double_quote_standard = '"' + single_quote_standard = "'" + + for unicode_val in double_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, double_quote_standard) + + for unicode_val in single_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, single_quote_standard) + + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 6 pattern: precompute chars using tuple comprehension + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = { + '"': "U+0022", + '"': "U+201C", + '"': "U+201D", +} +single_quotes = { + "'": "U+0027", + "'": "U+2018", + "'": "U+2019", +} + +_double_chars = tuple(unicode_to_char(u) for u in double_quotes.values()) +_single_chars = tuple(unicode_to_char(u) for u in single_quotes.values()) + +_QUOTE_TRANSLATION: dict[int, str] = {} +_double_quote_standard = '"' +_single_quote_standard = "'" + +for _ch in _double_chars: + _QUOTE_TRANSLATION[ord(_ch)] = _double_quote_standard + +for _ch in _single_chars: + _QUOTE_TRANSLATION[ord(_ch)] = _single_quote_standard + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + raise AssertionError( + f"Candidate 6 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify key elements are present + assert "_double_chars = tuple" in result, "_double_chars assignment should be present" + assert "for _ch in _double_chars" in result, "First for-loop should be present" + assert "for _ch in _single_chars" in result, "Second for-loop should be present" + + # Verify ordering: unicode_to_char must come before _double_chars assignment + func_pos = result.index("def unicode_to_char") + double_chars_pos = result.index("_double_chars = tuple") + assert func_pos < double_chars_pos, "unicode_to_char must be defined before _double_chars" + + +def test_standardize_quotes_optimization_candidate_7_pattern() -> None: + """Test optimization pattern from candidate 7 that caused NameError: name 'unicode_to_char' is not defined. + + This pattern moves quote dictionaries to module level and builds translation + table using for-loops that call unicode_to_char. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 7 pattern: module-level for-loops calling unicode_to_char + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_UNICODE_VALUES = [ + "U+0022", + "U+201C", + "U+201D", +] + +_SINGLE_QUOTE_UNICODE_VALUES = [ + "U+0027", + "U+2018", + "U+2019", +] + +_QUOTE_TRANSLATION_TABLE = {} +_double_replacement_ord = ord('"') +_single_replacement_ord = ord("'") + +for u in _DOUBLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(u) + _QUOTE_TRANSLATION_TABLE[ord(ch)] = _double_replacement_ord + +for u in _SINGLE_QUOTE_UNICODE_VALUES: + ch = unicode_to_char(u) + _QUOTE_TRANSLATION_TABLE[ord(ch)] = _single_replacement_ord + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION_TABLE) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is correctly populated by the for-loops + # This verifies that the for-loops execute correctly and call unicode_to_char + translation_table = namespace.get("_QUOTE_TRANSLATION_TABLE") + assert translation_table is not None, "_QUOTE_TRANSLATION_TABLE should be defined" + # 8220 = U+201C (left double quote), 8221 = U+201D (right double quote) + # 34 = ord('"') = ASCII double quote + assert 8220 in translation_table, "Left double quote (U+201C) should be in table" + assert 8221 in translation_table, "Right double quote (U+201D) should be in table" + assert translation_table[8220] == 34, "Left double quote should map to ASCII quote" + assert translation_table[8221] == 34, "Right double quote should map to ASCII quote" + + except NameError as e: + raise AssertionError( + f"Candidate 7 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify for-loops are present and ordered correctly + assert "for u in _DOUBLE_QUOTE_UNICODE_VALUES" in result + assert "for u in _SINGLE_QUOTE_UNICODE_VALUES" in result + + func_pos = result.index("def unicode_to_char") + first_loop_pos = result.index("for u in _DOUBLE_QUOTE_UNICODE_VALUES") + assert func_pos < first_loop_pos, "unicode_to_char must be defined before for-loops" + + +def test_standardize_quotes_optimization_candidate_9_pattern() -> None: + """Test optimization pattern from candidate 9 that caused NameError: name 'unicode_to_char' is not defined. + + This pattern is similar to candidate 7 but with slightly different variable names. + It builds a translation table at module level using for-loops. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original code with unicode_to_char defined AFTER standardize_quotes + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + double_quotes = {'"': "U+0022", '"': "U+201C"} + single_quotes = {"'": "U+0027", "'": "U+2018"} + + for unicode_val in double_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, '"') + + for unicode_val in single_quotes.values(): + unicode_char = unicode_to_char(unicode_val) + if unicode_char in text: + text = text.replace(unicode_char, "'") + + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization candidate 9 pattern + # Use proper Unicode escape sequences for the curly quote characters as keys + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +double_quotes = {'"': "U+0022", '\u201c': "U+201C", '\u201d': "U+201D"} +single_quotes = {"'": "U+0027", '\u2018': "U+2018", '\u2019': "U+2019"} + +_translation_table = {} + +_double_standard_ord = ord('"') +for unicode_val in double_quotes.values(): + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = _double_standard_ord + +_single_standard_ord = ord("'") +for unicode_val in single_quotes.values(): + ch = unicode_to_char(unicode_val) + _translation_table[ord(ch)] = _single_standard_ord + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_translation_table) +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is correctly populated by the for-loops + # This verifies that the for-loops execute correctly and call unicode_to_char + translation_table = namespace.get("_translation_table") + assert translation_table is not None, "_translation_table should be defined" + # 8220 = U+201C (left double quote), 8221 = U+201D (right double quote) + # 8216 = U+2018 (left single quote), 8217 = U+2019 (right single quote) + assert 8220 in translation_table, "Left double quote (U+201C) should be in table" + assert 8221 in translation_table, "Right double quote (U+201D) should be in table" + assert 8216 in translation_table, "Left single quote (U+2018) should be in table" + assert 8217 in translation_table, "Right single quote (U+2019) should be in table" + # Verify mappings to ASCII quotes + assert translation_table[8220] == 34, "Left double quote should map to ASCII double quote" + assert translation_table[8221] == 34, "Right double quote should map to ASCII double quote" + assert translation_table[8216] == 39, "Left single quote should map to ASCII single quote" + assert translation_table[8217] == 39, "Right single quote should map to ASCII single quote" + + except NameError as e: + raise AssertionError( + f"Candidate 9 pattern failed with NameError: {e}\n\nGenerated code:\n{result}" + ) from e + + # Verify ordering: unicode_to_char must come before module-level for-loops + # Use \nfor to find module-level (unindented) for-loops, not those inside functions + func_pos = result.index("def unicode_to_char") + first_loop = result.index("\nfor unicode_val in double_quotes.values()") + second_loop = result.index("\nfor unicode_val in single_quotes.values()") + + assert func_pos < first_loop, "unicode_to_char must be defined before first for-loop" + assert func_pos < second_loop, "unicode_to_char must be defined before second for-loop" + assert first_loop < second_loop, "For-loops should maintain their relative order" + + +def test_standardize_quotes_testgen_postprocessing_with_translation_table() -> None: + """Test end-to-end postprocessing pipeline for standardize_quotes with translation table pattern. + + This simulates the testgen endpoint returning generated tests for an optimization + that uses module-level for-loops to build a translation table, and verifies + the postprocessing pipeline (add_global_assignments + test postprocessing) works correctly. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original code (what we're optimizing) + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Simulated testgen optimization response with module-level for-loops + # This pattern builds a translation table at module level + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_QUOTE_CODES = ["U+0022", "U+201C", "U+201D"] +_SINGLE_QUOTE_CODES = ["U+0027", "U+2018", "U+2019"] + +_QUOTE_TABLE = {} + +for _code in _DOUBLE_QUOTE_CODES: + _ch = unicode_to_char(_code) + _QUOTE_TABLE[ord(_ch)] = ord('"') + +for _code in _SINGLE_QUOTE_CODES: + _ch = unicode_to_char(_code) + _QUOTE_TABLE[ord(_ch)] = ord("'") + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TABLE) +''' + + # Simulated generated test code + generated_test_source = '''import pytest +from module import standardize_quotes + +def test_standardize_quotes_basic(): + """Test basic quote standardization.""" + result = standardize_quotes("Hello world") + assert result == "Hello world" + +def test_standardize_quotes_unicode_double(): + """Test unicode double quote conversion.""" + result = standardize_quotes("Say \\u201chello\\u201d") + assert '"' in result + +def test_standardize_quotes_empty(): + """Test empty string.""" + result = standardize_quotes("") + assert result == "" +''' + + # Step 1: Process optimization code through add_global_assignments + processed_optimization = add_global_assignments(optimized_code, original_code) + + # Verify optimization code compiles and executes without NameError + try: + compiled = compile(processed_optimization, "", "exec") + namespace = {} + exec(compiled, namespace) + + # Verify the translation table is populated + quote_table = namespace.get("_QUOTE_TABLE") + assert quote_table is not None, "_QUOTE_TABLE should be defined" + assert 8220 in quote_table, "Left double quote (U+201C) should be in table" + assert 8221 in quote_table, "Right double quote (U+201D) should be in table" + except NameError as e: + raise AssertionError(f"Optimization code failed with NameError: {e}\n\n{processed_optimization}") from e + + # Step 2: Process generated tests through remove_functions_from_generated_tests + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_standardize_quotes.py"), + perf_file_path=Path("test_standardize_quotes_perf.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Simulate removing a failed test + processed_tests = remove_functions_from_generated_tests( + generated_tests_list, ["test_standardize_quotes_empty"] + ) + + # Verify the test was removed and others remain + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_quotes_basic" in result_source + assert "test_standardize_quotes_unicode_double" in result_source + assert "test_standardize_quotes_empty" not in result_source + + +def test_standardize_quotes_testgen_postprocessing_with_dict_comprehension() -> None: + """Test postprocessing pipeline for standardize_quotes with dict comprehension pattern. + + This simulates an optimization that uses dict comprehension with calls to + a helper function, ensuring the global assignments are correctly ordered. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text + + +def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) +''' + + # Optimization using dict comprehension (depends on unicode_to_char being defined first) + optimized_code = '''def unicode_to_char(unicode_val: str) -> str: + """Converts a Unicode value to a character.""" + return chr(int(unicode_val.replace("U+", ""), 16)) + +_DOUBLE_UNICODES = {"U+0022": '"', "U+201C": '"', "U+201D": '"'} +_SINGLE_UNICODES = {"U+0027": "'", "U+2018": "'", "U+2019": "'"} + +# Build translation table using comprehension +_TRANSLATION = { + ord(unicode_to_char(code)): ord(target) + for mapping in [_DOUBLE_UNICODES, _SINGLE_UNICODES] + for code, target in mapping.items() +} + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_TRANSLATION) +''' + + # Process optimization code + processed_optimization = add_global_assignments(optimized_code, original_code) + + # Verify code compiles and translation dict is built + try: + compiled = compile(processed_optimization, "", "exec") + namespace = {} + exec(compiled, namespace) + + translation = namespace.get("_TRANSLATION") + assert translation is not None, "_TRANSLATION should be defined" + # Verify the mapping contains unicode quote codes + assert len(translation) >= 4, "Translation table should have at least 4 entries" + except NameError as e: + raise AssertionError(f"Dict comprehension pattern failed with NameError: {e}") from e + + # Create mock generated tests with runtime data + generated_test_source = '''def test_standardize_double_quotes(): + result = standardize_quotes("\\u201chello\\u201d") + assert result == '"hello"' + +def test_standardize_single_quotes(): + result = standardize_quotes("\\u2018world\\u2019") + assert result == "'world'" +''' + + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_quotes.py"), + perf_file_path=Path("test_quotes_perf.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Mock runtime data for add_runtime_comments_to_generated_tests + # (Empty dicts since we don't have actual runtime data in this test) + original_runtimes: dict = {} + optimized_runtimes: dict = {} + + # Process through runtime comments (should handle empty runtimes gracefully) + processed_tests = add_runtime_comments_to_generated_tests( + generated_tests_list, original_runtimes, optimized_runtimes + ) + + # Verify tests are still valid + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_double_quotes" in result_source + assert "test_standardize_single_quotes" in result_source + + +def test_standardize_quotes_testgen_full_pipeline_integration() -> None: + """Test complete integration of testgen postprocessing pipeline for standardize_quotes. + + This test simulates the full flow: + 1. Original code with the function to optimize + 2. LLM-generated optimization with module-level for-loops + 3. Processing through add_global_assignments + 4. Generated tests processing through remove_functions + add_runtime_comments + 5. Final verification that everything compiles and is correct + """ + from codeflash.code_utils.code_extractor import add_global_assignments + from codeflash.code_utils.edit_generated_tests import ( + add_runtime_comments_to_generated_tests, + remove_functions_from_generated_tests, + ) + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + # Original FTO code + original_code = '''def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes. + + Args: + text: Input text that may contain unicode quotes. + + Returns: + Text with unicode quotes replaced by ASCII quotes. + """ + double_quotes = {'"': "U+0022", '\u201c': "U+201C", '\u201d': "U+201D"} + single_quotes = {"'": "U+0027", '\u2018': "U+2018", '\u2019': "U+2019"} + + for char, unicode_val in double_quotes.items(): + if char != '"': + text = text.replace(char, '"') + + for char, unicode_val in single_quotes.items(): + if char != "'": + text = text.replace(char, "'") + + return text +''' + + # LLM-generated optimization with a NEW helper function and module-level call + # This tests that add_global_assignments transfers new function definitions + optimized_code = '''def _build_translation_table() -> dict[int, int]: + """Build translation table for quote standardization.""" + table = {} + # Double quotes + for code in [0x0022, 0x201C, 0x201D]: + table[code] = 0x0022 # Map to ASCII double quote + # Single quotes + for code in [0x0027, 0x2018, 0x2019]: + table[code] = 0x0027 # Map to ASCII single quote + return table + +_QUOTE_TRANSLATION_TABLE = _build_translation_table() + + +def standardize_quotes(text: str) -> str: + """Converts all unicode quotes to standard ASCII quotes.""" + return text.translate(_QUOTE_TRANSLATION_TABLE) +''' + + # Step 1: Process optimization through add_global_assignments + processed_code = add_global_assignments(optimized_code, original_code) + + # Verify optimization code structure - new function should be transferred + assert "_build_translation_table" in processed_code, "New helper function should be present" + assert "_QUOTE_TRANSLATION_TABLE = _build_translation_table()" in processed_code + + # Verify code executes without errors + namespace = {} + exec(compile(processed_code, "", "exec"), namespace) + + # Verify the translation table is correctly built + table = namespace["_QUOTE_TRANSLATION_TABLE"] + assert table[0x201C] == 0x0022, "Left double quote should map to ASCII double quote" + assert table[0x201D] == 0x0022, "Right double quote should map to ASCII double quote" + assert table[0x2018] == 0x0027, "Left single quote should map to ASCII single quote" + assert table[0x2019] == 0x0027, "Right single quote should map to ASCII single quote" + + # Step 2: Create mock generated tests + generated_test_source = '''import pytest + +def test_standardize_quotes_no_change(): + """Test text without unicode quotes.""" + result = standardize_quotes("Hello, World!") + assert result == "Hello, World!" + +def test_standardize_quotes_double_unicode(): + """Test double quote unicode conversion.""" + text = "She said \\u201cHello\\u201d" + result = standardize_quotes(text) + assert result == 'She said "Hello"' + +def test_standardize_quotes_single_unicode(): + """Test single quote unicode conversion.""" + text = "It\\u2019s a test" + result = standardize_quotes(text) + assert result == "It's a test" + +def test_standardize_quotes_mixed(): + """Test mixed quote types.""" + text = "\\u201cIt\\u2019s \\u2018great\\u2019\\u201d" + result = standardize_quotes(text) + assert result == '"It\\'s \\'great\\'\"' + +def test_standardize_quotes_failing(): + """This test will be removed as failed.""" + result = standardize_quotes("test") + assert result == "wrong" +''' + + generated_test = GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source="instrumented_behavior_placeholder", + instrumented_perf_test_source="instrumented_perf_placeholder", + behavior_file_path=Path("tests/test_standardize_quotes__unit_test_0.py"), + perf_file_path=Path("tests/test_standardize_quotes__perf_test_0.py"), + ) + generated_tests_list = GeneratedTestsList(generated_tests=[generated_test]) + + # Step 3: Remove failed tests + tests_to_remove = ["test_standardize_quotes_failing"] + processed_tests = remove_functions_from_generated_tests(generated_tests_list, tests_to_remove) + + result_source = processed_tests.generated_tests[0].generated_original_test_source + assert "test_standardize_quotes_no_change" in result_source + assert "test_standardize_quotes_double_unicode" in result_source + assert "test_standardize_quotes_single_unicode" in result_source + assert "test_standardize_quotes_mixed" in result_source + assert "test_standardize_quotes_failing" not in result_source + + # Step 4: Add runtime comments (with empty data to test graceful handling) + final_tests = add_runtime_comments_to_generated_tests(processed_tests, {}, {}) + + # Verify final output is still valid Python + final_source = final_tests.generated_tests[0].generated_original_test_source + compile(final_source, "", "exec") # Should not raise + + # Count remaining test functions + test_count = final_source.count("def test_") + assert test_count == 4, f"Should have 4 tests after removing failed one, got {test_count}" From 6936685d9fefcd5980b0b9ff174e6c294b0c8dae Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 07:17:54 -0500 Subject: [PATCH 08/14] pass args --- codeflash/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/main.py b/codeflash/main.py index 31afd0305..67880a7fc 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -32,7 +32,7 @@ def main() -> None: disable_telemetry = pyproject_config.get("disable_telemetry", False) init_sentry(not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not disable_telemetry) - args.func() + args.func(args) elif args.verify_setup: args = process_pyproject_config(args) init_sentry(not args.disable_telemetry, exclude_errors=True) From 62291d76df92016aa1a68812aa4e1dedf4a8f1eb Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 07:24:23 -0500 Subject: [PATCH 09/14] fix: set default git-remote to origin in create-pr command --- codeflash/cli_cmds/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 2de7a804d..4da82bc6e 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -70,7 +70,9 @@ def parse_args() -> Namespace: create_pr_parser.add_argument( "--function", type=str, help="Function name (required if multiple functions in results)" ) - create_pr_parser.add_argument("--git-remote", type=str, help="Git remote to use for PR creation (default: origin)") + create_pr_parser.add_argument( + "--git-remote", type=str, default="origin", help="Git remote to use for PR creation (default: origin)" + ) parser.add_argument("--file", help="Try to optimize only this file") parser.add_argument("--function", help="Try to optimize only this function within the given file path") From a3f889b7fd96141f45a77ddb3256499ea5056727 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 19 Jan 2026 10:37:07 -0500 Subject: [PATCH 10/14] feat: store precomputed test reports for create-pr command Capture test report summary and loop count during phase1 optimization so the create-pr CLI can generate PR comments without re-running tests. --- codeflash/cli_cmds/cmd_create_pr.py | 2 ++ codeflash/github/PrComment.py | 24 +++++++++++++++----- codeflash/models/models.py | 3 +++ codeflash/optimization/function_optimizer.py | 11 +++++++++ codeflash/result/create_pr.py | 6 +++++ pyproject.toml | 1 + uv.lock | 2 ++ 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/codeflash/cli_cmds/cmd_create_pr.py b/codeflash/cli_cmds/cmd_create_pr.py index d5ca3946f..902c6ff8e 100644 --- a/codeflash/cli_cmds/cmd_create_pr.py +++ b/codeflash/cli_cmds/cmd_create_pr.py @@ -138,6 +138,8 @@ def create_pr(args: Namespace) -> None: optimization_review="", root_dir=git_root_dir(), git_remote=getattr(args, "git_remote", None), + precomputed_test_report=func_result.test_report, + precomputed_loop_count=func_result.loop_count, ) # Cleanup results file after successful PR creation diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index fe0ff095e..3a1021d54 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -23,13 +23,25 @@ class PrComment: benchmark_details: Optional[list[BenchmarkDetail]] = None original_async_throughput: Optional[int] = None best_async_throughput: Optional[int] = None + # Optional pre-computed values (used by create-pr CLI command) + precomputed_test_report: Optional[dict[str, dict[str, int]]] = None + precomputed_loop_count: Optional[int] = None def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]: - report_table = { - test_type.to_name(): result - for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() - if test_type.to_name() - } + # Use precomputed values if available, otherwise compute from TestResults + if self.precomputed_test_report is not None: + report_table = self.precomputed_test_report + else: + report_table = { + test_type.to_name(): result + for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() + if test_type.to_name() + } + loop_count = ( + self.precomputed_loop_count + if self.precomputed_loop_count is not None + else self.winning_benchmarking_test_results.number_of_loops() + ) result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = { "optimization_explanation": self.optimization_explanation, @@ -39,7 +51,7 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B "file_path": self.relative_file_path, "speedup_x": self.speedup_x, "speedup_pct": self.speedup_pct, - "loop_count": self.winning_benchmarking_test_results.number_of_loops(), + "loop_count": loop_count, "report_table": report_table, "benchmark_details": self.benchmark_details if self.benchmark_details else None, } diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d886aa633..016ce73a0 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -945,6 +945,9 @@ class Phase1FunctionResult(BaseModel): concolic_tests_source: Optional[str] = None best_candidate_explanation: Optional[str] = None best_runtime_ns: Optional[int] = None + # Test results summary for PR creation + test_report: Optional[dict[str, dict[str, int]]] = None # test_type_name -> {passed: int, failed: int} + loop_count: Optional[int] = None class Phase1Output(BaseModel): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 54d67220f..381561d6c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -486,6 +486,8 @@ def __init__( self.phase1_concolic_tests: str | None = None self.phase1_best_explanation: str | None = None self.phase1_best_runtime: int | None = None + self.phase1_test_report: dict[str, dict[str, int]] | None = None + self.phase1_loop_count: int | None = None def load_augmented_prompts(self) -> AugmentedPrompts | None: if not self.augmented_prompt_file: @@ -549,6 +551,8 @@ def get_phase1_function_result( concolic_tests_source=self.phase1_concolic_tests, best_candidate_explanation=self.phase1_best_explanation, best_runtime_ns=self.phase1_best_runtime, + test_report=self.phase1_test_report, + loop_count=self.phase1_loop_count, ) def write_phase1_output(self, function_result: Phase1FunctionResult) -> None: @@ -2099,6 +2103,13 @@ def process_review( self.phase1_concolic_tests = concolic_tests self.phase1_best_explanation = new_explanation.explanation_message() self.phase1_best_runtime = best_optimization.runtime + # Capture test results for PR creation + self.phase1_test_report = { + tt.to_name(): counts + for tt, counts in new_explanation.winning_behavior_test_results.get_test_pass_fail_report_by_type().items() + if tt.to_name() # Skip empty names + } + self.phase1_loop_count = new_explanation.winning_benchmarking_test_results.number_of_loops() data = { "original_code": original_code_combined, diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index f888f710a..da3ca1a72 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -186,6 +186,8 @@ def check_create_pr( root_dir: Path, git_remote: Optional[str] = None, optimization_review: str = "", + precomputed_test_report: Optional[dict[str, dict[str, int]]] = None, + precomputed_loop_count: Optional[int] = None, ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() git_repo = git.Repo(search_parent_directories=True) @@ -222,6 +224,8 @@ def check_create_pr( benchmark_details=explanation.benchmark_details, original_async_throughput=explanation.original_async_throughput, best_async_throughput=explanation.best_async_throughput, + precomputed_test_report=precomputed_test_report, + precomputed_loop_count=precomputed_loop_count, ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, @@ -274,6 +278,8 @@ def check_create_pr( benchmark_details=explanation.benchmark_details, original_async_throughput=explanation.original_async_throughput, best_async_throughput=explanation.best_async_throughput, + precomputed_test_report=precomputed_test_report, + precomputed_loop_count=precomputed_loop_count, ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, diff --git a/pyproject.toml b/pyproject.toml index bc9dc12db..9c6d8f312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "codeflash-benchmark", "filelock", "pytest-asyncio>=1.2.0", + "pyyaml>=6.0.3", ] [project.urls] diff --git a/uv.lock b/uv.lock index 3012a779a..5c5d7df98 100644 --- a/uv.lock +++ b/uv.lock @@ -433,6 +433,7 @@ dependencies = [ { name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pytest-timeout" }, + { name = "pyyaml" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "tomlkit" }, @@ -518,6 +519,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, + { name = "pyyaml", specifier = ">=6.0.3" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "tomlkit", specifier = ">=0.11.7" }, From fe3bf4ed1a057e13815961cdd153ac0282a3f696 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 21 Jan 2026 02:20:39 -0500 Subject: [PATCH 11/14] fix: insert new global statements after the globals they depend on Previously, ImportInserter inserted new global statements right after imports, which caused NameError when those statements referenced other globals defined later in the file (e.g., `_CACHE: dict = {None: tbl}` being inserted before `tbl` was defined). Now analyzes dependencies in each new statement and inserts it after the last definition of any name it references. --- codeflash/code_utils/code_extractor.py | 73 +++++++++++++++++++++----- tests/test_code_replacement.py | 70 ++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 12 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 3205dd2f7..4500d050e 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -563,10 +563,51 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node.with_changes(body=updated_body) +def get_statement_dependencies(stmt: cst.BaseStatement) -> set[str]: + """Extract all names that a statement depends on (names used on the RHS).""" + deps: set[str] = set() + if isinstance(stmt, cst.SimpleStatementLine): + for body_item in stmt.body: + if isinstance(body_item, cst.Assign): + # Get names used in the value (RHS) + deps.update(get_names_in_expression(body_item.value)) + elif isinstance(body_item, cst.AnnAssign) and body_item.value is not None: + # Get names used in the value and annotation + deps.update(get_names_in_expression(body_item.value)) + deps.update(get_names_in_expression(body_item.annotation.annotation)) + return deps + + +def get_statement_defined_name(stmt: cst.BaseStatement) -> str | None: + """Get the name defined by a statement (LHS of assignment).""" + if isinstance(stmt, cst.SimpleStatementLine) and len(stmt.body) == 1: + body_item = stmt.body[0] + if isinstance(body_item, cst.Assign) and len(body_item.targets) == 1: + target = body_item.targets[0].target + if isinstance(target, cst.Name): + return target.value + elif isinstance(body_item, cst.AnnAssign): + target = body_item.target + if isinstance(target, cst.Name): + return target.value + return None + + +def find_last_definition_index(name: str, body: list[cst.BaseStatement]) -> int: + """Find the index of the last statement that defines a given name.""" + last_idx = -1 + for i, stmt in enumerate(body): + defined_name = get_statement_defined_name(stmt) + if defined_name == name: + last_idx = i + return last_idx + + class ImportInserter(cst.CSTTransformer): """Transformer that inserts global statements into a module. - - Simple statements (assignments, etc.) are inserted after imports + - Simple statements (assignments, etc.) are inserted after the globals they depend on + - If no dependencies, they are inserted after imports - Compound statements (for, while, with, try) are inserted at the END of the module because they may call functions that are defined later in the file """ @@ -592,21 +633,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c updated_body = list(updated_node.body) - # Insert simple statements after imports (or at beginning if no imports) + # For simple statements, insert after their dependencies or after imports if simple_statements: - if self.last_import_line == 0: - # No imports, insert at the beginning - for j, stmt in enumerate(simple_statements): - updated_body.insert(j, stmt) - else: - # Find insertion point after last import - insertion_index = 0 + # Find the base insertion point (after imports) + base_insertion_index = 0 + if self.last_import_line > 0: for i in range(len(updated_body)): if i + 1 == self.last_import_line: - insertion_index = i + 1 + base_insertion_index = i + 1 break - for j, stmt in enumerate(simple_statements): - updated_body.insert(insertion_index + j, stmt) + + # Insert each statement at the correct position based on its dependencies + for stmt in simple_statements: + deps = get_statement_dependencies(stmt) + + # Find the position after the last dependency definition + insertion_index = base_insertion_index + for dep_name in deps: + dep_idx = find_last_definition_index(dep_name, updated_body) + if dep_idx >= 0: + # Insert after this dependency + insertion_index = max(insertion_index, dep_idx + 1) + + updated_body.insert(insertion_index, stmt) # Append compound statements at the END of the module # This ensures they run after all function definitions diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index b5c52acb1..bc13080a3 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -4330,6 +4330,76 @@ def standardize_quotes(text: str) -> str: assert func_pos < second_forloop_pos, "Function must come before second for-loop" +def test_add_global_assignments_variable_depends_on_existing_global() -> None: + """Test that new global assignments depending on existing globals are inserted after them. + + This tests the fix for a bug where LLM-generated optimizations that add module-level + cache variables like `_TRANSLATION_CACHE: dict = {None: tbl}` would fail because the + assignment was inserted after imports but BEFORE the `tbl` variable it depends on. + + Real-world example from unstructured/cleaners/core.py optimization: + - Original has: `tbl = dict.fromkeys(...)` + - Optimized adds: `_TRANSLATION_CACHE: dict = {None: tbl}` (depends on tbl) + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + # Original file with imports and module-level variable + original_code = '''from __future__ import annotations +import sys +import unicodedata + +tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") +) + +def remove_sentence_punctuation(s: str) -> str: + tbl_new = tbl.copy() + return s.translate(tbl_new) +''' + + # Optimized code adds a cache that depends on `tbl` + optimized_code = '''from __future__ import annotations +import sys +import unicodedata + +tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") +) + +# Cache for translation tables +_TRANSLATION_CACHE: dict = {None: tbl} + +def remove_sentence_punctuation(s: str) -> str: + tbl_new = tbl.copy() + return s.translate(tbl_new) +''' + + result = add_global_assignments(optimized_code, original_code) + + # The result should be valid Python - no NameError when executed + try: + compile(result, "", "exec") + exec(compile(result, "", "exec"), {}) + except NameError as e: + raise AssertionError( + f"Generated code has NameError: {e}\n\nGenerated code:\n{result}" + ) from e + except SyntaxError as e: + raise AssertionError(f"Generated code has SyntaxError: {e}") from e + + # Verify `_TRANSLATION_CACHE` is in the result + assert "_TRANSLATION_CACHE" in result, "_TRANSLATION_CACHE should be in the result" + + # Verify correct ordering: `tbl` must come before `_TRANSLATION_CACHE` + tbl_pos = result.index("tbl = dict.fromkeys") + cache_pos = result.index("_TRANSLATION_CACHE") + + assert cache_pos > tbl_pos, ( + f"_TRANSLATION_CACHE (pos {cache_pos}) should be after " + f"tbl (pos {tbl_pos}) because it depends on it" + ) + + # ============================================================================= # Real-world standardize_quotes optimization tests # These tests verify the fixes work for the actual optimization scenarios From 9850f9f22a19e6e50d357a6e9bace3679eceab89 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 21 Jan 2026 02:36:24 -0500 Subject: [PATCH 12/14] fix: handle tuple unpacking and chained assignments in dependency tracking - Add _extract_names_from_target() to handle tuple/list unpacking targets - Replace get_statement_defined_name() with get_statement_defined_names() returning a set to support multiple names per statement - Add _sort_statements_by_dependencies() for topological sorting of new statements before insertion - Fix GlobalStatementCollector to include tuple/chained assignments - Fix GlobalAssignmentTransformer to recognize tuple assignments in definition_positions - Remove type annotations from dependency tracking (they don't cause runtime NameErrors) - Add tests for tuple unpacking, chained assignments, multiple statements, and annotated assignments --- codeflash/code_utils/code_extractor.py | 133 +++++++++++++++---- tests/test_code_replacement.py | 171 +++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 24 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4500d050e..e131553f4 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -265,12 +265,17 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): definition_positions[stmt.name.value] = i + 1 elif isinstance(stmt, cst.SimpleStatementLine): - # Check for assignments + # Check for assignments (including tuple unpacking and chained assignments) for child in stmt.body: if isinstance(child, cst.Assign): for target in child.targets: - if isinstance(target.target, cst.Name): - definition_positions[target.target.value] = i + 1 + # Handle all target types (Name, Tuple, etc.) + for name in _extract_names_from_target(target.target): + definition_positions[name] = i + 1 + elif isinstance(child, cst.AnnAssign): + # Handle annotated assignments + for name in _extract_names_from_target(child.target): + definition_positions[name] = i + 1 # Find the default insertion index (after imports) default_insert_index = find_insertion_index_after_imports(updated_node) @@ -433,9 +438,20 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: if not self.in_function_or_class and self.compound_depth == 0: for statement in node.body: # Skip imports - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): + if isinstance(statement, (cst.Import, cst.ImportFrom)): + continue + # Skip simple name assignments (handled by GlobalAssignmentCollector) + # But include tuple unpacking and chained assignments + if isinstance(statement, cst.Assign): + # Check if it's a simple single-name assignment + if len(statement.targets) == 1 and isinstance(statement.targets[0].target, cst.Name): + continue + # Tuple unpacking, chained assignment, etc. - include it self.global_statements.append(node) break + # Include other statement types (AnnAssign, Expr, etc.) + self.global_statements.append(node) + break class LastImportFinder(cst.CSTVisitor): @@ -564,7 +580,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c def get_statement_dependencies(stmt: cst.BaseStatement) -> set[str]: - """Extract all names that a statement depends on (names used on the RHS).""" + """Extract all names that a statement depends on (names used on the RHS). + + Note: Type annotations are NOT included as dependencies since they only affect + static type checking and don't cause runtime NameErrors. + """ deps: set[str] = set() if isinstance(stmt, cst.SimpleStatementLine): for body_item in stmt.body: @@ -572,37 +592,97 @@ def get_statement_dependencies(stmt: cst.BaseStatement) -> set[str]: # Get names used in the value (RHS) deps.update(get_names_in_expression(body_item.value)) elif isinstance(body_item, cst.AnnAssign) and body_item.value is not None: - # Get names used in the value and annotation + # Only get names from the value, NOT from the annotation + # Type annotations don't create runtime dependencies deps.update(get_names_in_expression(body_item.value)) - deps.update(get_names_in_expression(body_item.annotation.annotation)) return deps -def get_statement_defined_name(stmt: cst.BaseStatement) -> str | None: - """Get the name defined by a statement (LHS of assignment).""" - if isinstance(stmt, cst.SimpleStatementLine) and len(stmt.body) == 1: - body_item = stmt.body[0] - if isinstance(body_item, cst.Assign) and len(body_item.targets) == 1: - target = body_item.targets[0].target - if isinstance(target, cst.Name): - return target.value - elif isinstance(body_item, cst.AnnAssign): - target = body_item.target - if isinstance(target, cst.Name): - return target.value - return None +def _extract_names_from_target(target: cst.BaseExpression) -> set[str]: + """Extract all names from an assignment target (handles tuples, names, etc.).""" + names: set[str] = set() + if isinstance(target, cst.Name): + names.add(target.value) + elif isinstance(target, (cst.Tuple, cst.List)): + for element in target.elements: + if isinstance(element, (cst.Element, cst.StarredElement)): + names.update(_extract_names_from_target(element.value)) + return names + + +def get_statement_defined_names(stmt: cst.BaseStatement) -> set[str]: + """Get all names defined by a statement (LHS of assignment).""" + names: set[str] = set() + if isinstance(stmt, cst.SimpleStatementLine): + for body_item in stmt.body: + if isinstance(body_item, cst.Assign): + # Handle chained assignments: a = b = c = 5 + for target_node in body_item.targets: + names.update(_extract_names_from_target(target_node.target)) + elif isinstance(body_item, cst.AnnAssign): + names.update(_extract_names_from_target(body_item.target)) + return names def find_last_definition_index(name: str, body: list[cst.BaseStatement]) -> int: """Find the index of the last statement that defines a given name.""" last_idx = -1 for i, stmt in enumerate(body): - defined_name = get_statement_defined_name(stmt) - if defined_name == name: + defined_names = get_statement_defined_names(stmt) + if name in defined_names: last_idx = i return last_idx +def _sort_statements_by_dependencies(statements: list[cst.BaseStatement]) -> list[cst.BaseStatement]: + """Sort statements so that definitions come before their dependents. + + Uses a stable topological sort - statements without dependencies on each other + maintain their original relative order. + """ + if len(statements) <= 1: + return statements + + # Build a map of defined names to statement index + name_to_idx: dict[str, int] = {} + for i, stmt in enumerate(statements): + for name in get_statement_defined_names(stmt): + name_to_idx[name] = i + + # Build dependency graph: idx -> set of indices it depends on + deps_graph: dict[int, set[int]] = {i: set() for i in range(len(statements))} + for i, stmt in enumerate(statements): + for dep_name in get_statement_dependencies(stmt): + if dep_name in name_to_idx: + dep_idx = name_to_idx[dep_name] + if dep_idx != i: # Don't add self-dependency + deps_graph[i].add(dep_idx) + + # Kahn's algorithm for topological sort with stable ordering + in_degree = {i: len(deps) for i, deps in deps_graph.items()} + # Start with nodes that have no dependencies, in original order + queue = [i for i in range(len(statements)) if in_degree[i] == 0] + result: list[cst.BaseStatement] = [] + + while queue: + # Take from front (maintains original order for independent items) + idx = queue.pop(0) + result.append(statements[idx]) + + # Find nodes that depend on this one and reduce their in-degree + for other_idx in range(len(statements)): + if idx in deps_graph[other_idx]: + in_degree[other_idx] -= 1 + if in_degree[other_idx] == 0: + queue.append(other_idx) + + # If we couldn't sort all (cycle), return original order + if len(result) != len(statements): + return statements + + return result + + class ImportInserter(cst.CSTTransformer): """Transformer that inserts global statements into a module. @@ -635,6 +715,9 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # For simple statements, insert after their dependencies or after imports if simple_statements: + # Sort statements by their interdependencies first + simple_statements = _sort_statements_by_dependencies(simple_statements) + # Find the base insertion point (after imports) base_insertion_index = 0 if self.last_import_line > 0: @@ -644,11 +727,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c break # Insert each statement at the correct position based on its dependencies - for stmt in simple_statements: + # Track how many statements have been inserted to maintain relative order + for inserted_count, stmt in enumerate(simple_statements): deps = get_statement_dependencies(stmt) # Find the position after the last dependency definition - insertion_index = base_insertion_index + # Account for already-inserted statements to maintain order + insertion_index = base_insertion_index + inserted_count for dep_name in deps: dep_idx = find_last_definition_index(dep_name, updated_body) if dep_idx >= 0: diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index bc13080a3..4e97d66f4 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -4400,6 +4400,177 @@ def remove_sentence_punctuation(s: str) -> str: ) +def test_add_global_assignments_tuple_unpacking() -> None: + """Test that tuple unpacking assignments are properly tracked. + + Verifies the fix for: a, b = 1, 2 where the target is a Tuple, not a Name. + Without the fix, neither 'a' nor 'b' would be tracked as defined. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with tuple unpacking that a later variable depends on + optimized_code = '''import sys + +a, b = 1, 2 +c = a + b + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Tuple unpacking test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: a, b = 1, 2 must come before c = a + b + unpack_pos = result.index("a, b = 1, 2") + c_pos = result.index("c = a + b") + assert unpack_pos < c_pos, "Tuple unpacking must come before dependent assignment" + + +def test_add_global_assignments_chained_assignment() -> None: + """Test that chained assignments are properly tracked. + + Verifies the fix for: a = b = c = 5 which has multiple targets. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with chained assignment that a later variable depends on + optimized_code = '''import sys + +a = b = c = 5 +d = a + b + c + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Chained assignment test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: a = b = c = 5 must come before d = a + b + c + chain_pos = result.index("a = b = c = 5") + d_pos = result.index("d = a + b + c") + assert chain_pos < d_pos, "Chained assignment must come before dependent assignment" + + +def test_add_global_assignments_multiple_new_statements() -> None: + """Test that multiple new statements maintain correct order. + + When inserting multiple statements with no dependencies, they should + maintain their relative order from the optimized code. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +def foo(): + pass +''' + + # Optimized code with multiple independent global assignments + optimized_code = '''import sys + +FIRST = 1 +SECOND = 2 +THIRD = 3 + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except (NameError, SyntaxError) as e: + msg = f"Multiple statements test failed: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify correct ordering: FIRST, SECOND, THIRD in order + first_pos = result.index("FIRST = 1") + second_pos = result.index("SECOND = 2") + third_pos = result.index("THIRD = 3") + assert first_pos < second_pos < third_pos, ( + f"Statements should maintain order: FIRST ({first_pos}) < SECOND ({second_pos}) < THIRD ({third_pos})" + ) + + +def test_add_global_assignments_annotated_no_spurious_deps() -> None: + """Test that type annotations don't create spurious dependencies. + + Verifies the fix for: x: Tuple[int, int] = value + Without the fix, Tuple and int would be added as spurious dependencies. + """ + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = '''import sys + +val = (1, 2) + +def foo(): + pass +''' + + # Optimized code with type annotation - the annotation uses Tuple[int, int] + # but the actual value only depends on 'val' + optimized_code = '''import sys + +val = (1, 2) +x: tuple[int, int] = val +y = x + +def foo(): + pass +''' + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + exec(compiled, {}) + except NameError as e: + msg = f"Annotated assignment test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + # Verify x assignment is in the result + assert "x: tuple[int, int] = val" in result, "Annotated assignment should be present" + + # Verify correct ordering: val must come before x, x must come before y + val_pos = result.index("val = (1, 2)") + x_pos = result.index("x: tuple[int, int] = val") + y_pos = result.index("y = x") + assert val_pos < x_pos < y_pos, "Variables should be in dependency order" + + # ============================================================================= # Real-world standardize_quotes optimization tests # These tests verify the fixes work for the actual optimization scenarios From 9efbabd3fdf635aa9cdcedb3763ea78a20c39e52 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 21 Jan 2026 03:11:55 -0500 Subject: [PATCH 13/14] fix: handle match statements, comprehension/lambda scoping, and circular dependency warnings - Add match statement (Python 3.10+) handling in get_statement_defined_names() - Fix comprehension variables leaking as external dependencies in NameCollector - Fix lambda parameters leaking as external dependencies - Add warning when circular dependencies are detected in _sort_statements_by_dependencies() - Add Match handling to GlobalAssignmentCollector, GlobalAssignmentTransformer, and GlobalStatementCollector --- codeflash/code_utils/code_extractor.py | 311 +++++++++++++++++++++++-- tests/test_code_replacement.py | 272 +++++++++++++++++++++ 2 files changed, 568 insertions(+), 15 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index e131553f4..b7588aa8b 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -30,18 +30,148 @@ class NameCollector(cst.CSTVisitor): """Collects all Name nodes referenced in a CST expression. Used to find what names an assignment depends on (e.g., function calls in the RHS). + Properly scopes comprehension variables and lambda parameters so they are not + treated as external dependencies. """ def __init__(self) -> None: super().__init__() self.names: set[str] = set() + # Stack of sets tracking locally bound names (comprehension vars, lambda params) + self._local_names_stack: list[set[str]] = [] + + def _is_local_name(self, name: str) -> bool: + """Check if a name is bound in any enclosing local scope.""" + return any(name in scope for scope in self._local_names_stack) def visit_Name(self, node: cst.Name) -> None: - self.names.add(node.value) + if not self._is_local_name(node.value): + self.names.add(node.value) + + # --- Comprehension handling --- + def visit_ListComp(self, node: cst.ListComp) -> bool: + """Handle list comprehensions: [expr for x in iterable if cond].""" + self._visit_comprehension(node.for_in, node.elt) + return False # Don't visit children - we handle them manually + + def visit_SetComp(self, node: cst.SetComp) -> bool: + """Handle set comprehensions: {expr for x in iterable if cond}.""" + self._visit_comprehension(node.for_in, node.elt) + return False + + def visit_DictComp(self, node: cst.DictComp) -> bool: + """Handle dict comprehensions: {key: value for x in iterable if cond}.""" + # For dict comps, we need to visit both key and value + self._visit_dict_comprehension(node.for_in, node.key, node.value) + return False + + def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool: + """Handle generator expressions: (expr for x in iterable if cond).""" + self._visit_comprehension(node.for_in, node.elt) + return False + + def _extract_comp_target_names(self, target: cst.BaseExpression) -> set[str]: + """Extract names bound by a comprehension target (e.g., x in 'for x in items').""" + return _extract_names_from_target(target) + + def _visit_comprehension(self, for_in: cst.CompFor, elt: cst.BaseExpression) -> None: + """Visit a comprehension, properly scoping the loop variables.""" + # Create a new scope for this comprehension + local_scope: set[str] = set() + self._local_names_stack.append(local_scope) + + # Process all for..in clauses, collecting bound names as we go + self._visit_comp_for(for_in, local_scope) + + # Visit the element expression (all bound names are now in scope) + elt.visit(self) + + # Pop the scope + self._local_names_stack.pop() + + def _visit_dict_comprehension( + self, for_in: cst.CompFor, key: cst.BaseExpression, value: cst.BaseExpression + ) -> None: + """Visit a dict comprehension, properly scoping the loop variables.""" + local_scope: set[str] = set() + self._local_names_stack.append(local_scope) + + self._visit_comp_for(for_in, local_scope) + + # Visit both key and value expressions + key.visit(self) + value.visit(self) + + self._local_names_stack.pop() + + def _visit_comp_for(self, comp_for: cst.CompFor, local_scope: set[str]) -> None: + """Recursively visit CompFor nodes, adding bound names to scope.""" + # First, visit the iterable (before adding target names to scope!) + # This is important: `x` in `[x for x in x]` - the RHS `x` is external + comp_for.iter.visit(self) + + # Add the target names to the local scope + target_names = self._extract_comp_target_names(comp_for.target) + local_scope.update(target_names) + + # Visit any conditions (ifs) + for if_clause in comp_for.ifs: + if_clause.test.visit(self) + + # Visit nested for..in if present + if comp_for.inner_for_in is not None: + self._visit_comp_for(comp_for.inner_for_in, local_scope) + + # --- Lambda handling --- + def visit_Lambda(self, node: cst.Lambda) -> bool: + """Handle lambda expressions: lambda x, y: x + y + z.""" + # Collect parameter names + param_names = self._extract_lambda_param_names(node.params) + + # Push a new scope with these parameters + self._local_names_stack.append(param_names) + + # Visit the lambda body + node.body.visit(self) + + # Pop the scope + self._local_names_stack.pop() + + return False # Don't visit children - we handle them manually + + def _extract_lambda_param_names(self, params: cst.Parameters) -> set[str]: + """Extract all parameter names from a lambda's parameters.""" + names: set[str] = set() + + # Regular positional/keyword params + for param in params.params: + names.add(param.name.value) + + # *args + if params.star_arg and isinstance(params.star_arg, cst.Param): + names.add(params.star_arg.name.value) + + # Keyword-only params + for param in params.kwonly_params: + names.add(param.name.value) + + # **kwargs + if params.star_kwarg: + names.add(params.star_kwarg.name.value) + + # Positional-only params (Python 3.8+) + for param in params.posonly_params: + names.add(param.name.value) + + return names def get_names_in_expression(node: cst.BaseExpression) -> set[str]: - """Extract all names referenced in a CST expression.""" + """Extract all names referenced in a CST expression. + + Comprehension variables and lambda parameters are properly scoped and + are NOT included as external dependencies. + """ collector = NameCollector() node.visit(collector) return collector.names @@ -120,6 +250,13 @@ def visit_Try(self, node: cst.Try) -> Optional[bool]: # noqa: ARG002 def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 self.compound_depth -= 1 + def visit_Match(self, node: cst.Match) -> Optional[bool]: # noqa: ARG002 + self.compound_depth += 1 + return True + + def leave_Match(self, original_node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_Assign(self, node: cst.Assign) -> Optional[bool]: # Only process global assignments (not inside functions, classes, loops, etc.) if self.scope_depth == 0 and self.if_else_depth == 0 and self.compound_depth == 0: @@ -230,6 +367,13 @@ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try: self.compound_depth -= 1 return updated_node + def visit_Match(self, node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth += 1 + + def leave_Match(self, original_node: cst.Match, updated_node: cst.Match) -> cst.Match: # noqa: ARG002 + self.compound_depth -= 1 + return updated_node + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode: if self.scope_depth > 0 or self.if_else_depth > 0 or self.compound_depth > 0: return updated_node @@ -264,18 +408,10 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c for i, stmt in enumerate(new_statements): if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): definition_positions[stmt.name.value] = i + 1 - elif isinstance(stmt, cst.SimpleStatementLine): - # Check for assignments (including tuple unpacking and chained assignments) - for child in stmt.body: - if isinstance(child, cst.Assign): - for target in child.targets: - # Handle all target types (Name, Tuple, etc.) - for name in _extract_names_from_target(target.target): - definition_positions[name] = i + 1 - elif isinstance(child, cst.AnnAssign): - # Handle annotated assignments - for name in _extract_names_from_target(child.target): - definition_positions[name] = i + 1 + # Get all names defined by this statement (handles simple assignments, + # compound statements like for/while/with/try, etc.) + for name in get_statement_defined_names(stmt): + definition_positions[name] = i + 1 # Find the default insertion index (after imports) default_insert_index = find_insertion_index_after_imports(updated_node) @@ -434,6 +570,15 @@ def visit_Try(self, node: cst.Try) -> Optional[bool]: def leave_Try(self, original_node: cst.Try) -> None: # noqa: ARG002 self.compound_depth -= 1 + def visit_Match(self, node: cst.Match) -> Optional[bool]: + if not self.in_function_or_class and self.compound_depth == 0: + self.global_statements.append(node) + self.compound_depth += 1 + return False # Don't visit children - we collect the whole node + + def leave_Match(self, original_node: cst.Match) -> None: # noqa: ARG002 + self.compound_depth -= 1 + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: if not self.in_function_or_class and self.compound_depth == 0: for statement in node.body: @@ -610,8 +755,87 @@ def _extract_names_from_target(target: cst.BaseExpression) -> set[str]: return names +def _extract_names_from_pattern(pattern: cst.MatchPattern) -> set[str]: + """Extract all names bound by a match pattern (Python 3.10+). + + Match patterns can bind names in several ways: + - MatchAs: `case x:` or `case _ as x:` + - MatchStar: `case [*rest]:` + - MatchMapping: `case {"key": value}:` + - MatchSequence/MatchList/MatchTuple: `case [x, y]:` + - MatchClass: `case Point(x=px, y=py):` + - MatchOr: `case x | y:` (names must be same in all alternatives) + """ + names: set[str] = set() + + if isinstance(pattern, cst.MatchAs): + # `case x:` or `case _ as x:` - the name is bound + if pattern.name is not None: + names.add(pattern.name.value) + # Also check nested pattern (e.g., `case [a, b] as both:`) + if pattern.pattern is not None: + names.update(_extract_names_from_pattern(pattern.pattern)) + + elif isinstance(pattern, cst.MatchStar): + # `case [first, *rest]:` - rest is bound + if pattern.name is not None: + names.add(pattern.name.value) + + elif isinstance(pattern, (cst.MatchSequence, cst.MatchList, cst.MatchTuple)): + # `case [x, y]:` or `case (x, y):` + for element in pattern.patterns: + if isinstance(element, cst.MatchSequenceElement): + names.update(_extract_names_from_pattern(element.value)) + elif isinstance(element, cst.MatchStar) and element.name is not None: + names.add(element.name.value) + + elif isinstance(pattern, cst.MatchMapping): + # `case {"key": value}:` + for element in pattern.elements: + if isinstance(element, cst.MatchMappingElement): + names.update(_extract_names_from_pattern(element.pattern)) + # Also handle `case {**rest}:` + if pattern.rest is not None: + names.add(pattern.rest.value) + + elif isinstance(pattern, cst.MatchClass): + # `case Point(x, y):` or `case Point(x=px, y=py):` + for arg in pattern.patterns: + if isinstance(arg, cst.MatchSequenceElement): + names.update(_extract_names_from_pattern(arg.value)) + for kwarg in pattern.keywords: + if isinstance(kwarg, cst.MatchKeywordElement): + names.update(_extract_names_from_pattern(kwarg.pattern)) + + elif isinstance(pattern, cst.MatchOr): + # `case x | y:` - all alternatives must bind same names + for element in pattern.patterns: + if isinstance(element, cst.MatchOrElement): + names.update(_extract_names_from_pattern(element.pattern)) + + # MatchValue and MatchSingleton don't bind names + + return names + + +def _collect_defined_names_from_body(body: cst.BaseSuite) -> set[str]: + """Recursively collect all names defined by assignments inside a code block. + + This is used to find names defined inside compound statements (for, while, with, try). + """ + names: set[str] = set() + if isinstance(body, cst.IndentedBlock): + for stmt in body.body: + names.update(get_statement_defined_names(stmt)) + return names + + def get_statement_defined_names(stmt: cst.BaseStatement) -> set[str]: - """Get all names defined by a statement (LHS of assignment).""" + """Get all names defined by a statement (LHS of assignment). + + For compound statements (for, while, with, try, if), this also collects + names defined inside their bodies. + """ names: set[str] = set() if isinstance(stmt, cst.SimpleStatementLine): for body_item in stmt.body: @@ -621,6 +845,52 @@ def get_statement_defined_names(stmt: cst.BaseStatement) -> set[str]: names.update(_extract_names_from_target(target_node.target)) elif isinstance(body_item, cst.AnnAssign): names.update(_extract_names_from_target(body_item.target)) + elif isinstance(stmt, cst.For): + # For loop target variable + names.update(_extract_names_from_target(stmt.target)) + # Names defined inside the for loop body + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt, cst.While): + # Names defined inside the while loop body + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt, cst.With): + # With statement can bind names via 'as' clause + for item in stmt.items: + if item.asname: + names.update(_extract_names_from_target(item.asname.name)) + # Names defined inside the with body + names.update(_collect_defined_names_from_body(stmt.body)) + elif isinstance(stmt, cst.Try): + # Names defined in try/except/else/finally bodies + names.update(_collect_defined_names_from_body(stmt.body)) + for handler in stmt.handlers: + if handler.name: + names.add(handler.name.value) + names.update(_collect_defined_names_from_body(handler.body)) + if stmt.orelse: + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + if stmt.finalbody: + names.update(_collect_defined_names_from_body(stmt.finalbody.body)) + elif isinstance(stmt, cst.If): + # Names defined inside if/elif/else bodies + names.update(_collect_defined_names_from_body(stmt.body)) + if stmt.orelse: + if isinstance(stmt.orelse, cst.Else): + names.update(_collect_defined_names_from_body(stmt.orelse.body)) + elif isinstance(stmt.orelse, cst.If): + # elif branch + names.update(get_statement_defined_names(stmt.orelse)) + elif isinstance(stmt, cst.Match): + # Match statement (Python 3.10+) can bind names through patterns + for case in stmt.cases: + # Extract names bound by the pattern + names.update(_extract_names_from_pattern(case.pattern)) + # Names defined inside the case body + names.update(_collect_defined_names_from_body(case.body)) return names @@ -678,6 +948,17 @@ def _sort_statements_by_dependencies(statements: list[cst.BaseStatement]) -> lis # If we couldn't sort all (cycle), return original order if len(result) != len(statements): + # Find the statements involved in the cycle for a more helpful warning + unsorted_indices = [i for i in range(len(statements)) if in_degree[i] > 0] + cycle_names = [] + for idx in unsorted_indices: + names = get_statement_defined_names(statements[idx]) + if names: + cycle_names.extend(names) + logger.warning( + f"Circular dependency detected among statements defining: {', '.join(cycle_names) or 'unknown'}. " + "Using original statement order. This may cause NameError at runtime." + ) return statements return result diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 4e97d66f4..381f32409 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -5204,3 +5204,275 @@ def test_standardize_quotes_failing(): # Count remaining test functions test_count = final_source.count("def test_") assert test_count == 4, f"Should have 4 tests after removing failed one, got {test_count}" + + +def test_match_statement_pattern_names(): + """Test that match statement patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + # Test basic match with MatchAs + code = """ +match value: + case [x, y]: + result = x + y + case {"name": name}: + other = name +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + # Should find x, y from first case, name from second case, result, other from bodies + assert "x" in defined_names, "Pattern variable 'x' should be defined" + assert "y" in defined_names, "Pattern variable 'y' should be defined" + assert "name" in defined_names, "Pattern variable 'name' should be defined" + assert "result" in defined_names, "Body variable 'result' should be defined" + assert "other" in defined_names, "Body variable 'other' should be defined" + + +def test_match_statement_star_pattern(): + """Test that match statement star patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match items: + case [first, *rest]: + total = sum(rest) +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "first" in defined_names, "Pattern variable 'first' should be defined" + assert "rest" in defined_names, "Star pattern variable 'rest' should be defined" + assert "total" in defined_names, "Body variable 'total' should be defined" + + +def test_match_statement_as_pattern(): + """Test that match statement 'as' patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match obj: + case [a, b] as both: + use = both +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "a" in defined_names, "Pattern variable 'a' should be defined" + assert "b" in defined_names, "Pattern variable 'b' should be defined" + assert "both" in defined_names, "'as' capture variable 'both' should be defined" + + +def test_match_statement_mapping_rest(): + """Test that match statement mapping rest patterns define names correctly.""" + from codeflash.code_utils.code_extractor import get_statement_defined_names + + code = """ +match config: + case {"known": val, **rest}: + extra = rest +""" + module = cst.parse_module(code) + match_stmt = module.body[0] + defined_names = get_statement_defined_names(match_stmt) + + assert "val" in defined_names, "Mapping value 'val' should be defined" + assert "rest" in defined_names, "Mapping rest '**rest' should be defined" + + +def test_comprehension_variable_not_dependency(): + """Test that comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "result = [x * 2 for x in items]" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Comprehension variable 'x' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_nested_comprehension_variables_not_dependencies(): + """Test that nested comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "matrix = [[row[col] for col in range(n)] for row in data]" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "row" not in deps, "Outer comprehension variable 'row' should not be a dependency" + assert "col" not in deps, "Inner comprehension variable 'col' should not be a dependency" + assert "data" in deps, "Outer iterable 'data' should be a dependency" + assert "n" in deps, "Inner range argument 'n' should be a dependency" + assert "range" in deps, "Built-in 'range' should be a dependency" + + +def test_dict_comprehension_variables_not_dependencies(): + """Test that dict comprehension variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "mapping = {k: v * 2 for k, v in items.items()}" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "k" not in deps, "Dict comprehension key variable 'k' should not be a dependency" + assert "v" not in deps, "Dict comprehension value variable 'v' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_generator_expression_variable_not_dependency(): + """Test that generator expression variables aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "gen = (x for x in items if x > 0)" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Generator expression variable 'x' should not be a dependency" + assert "items" in deps, "Iterable 'items' should be a dependency" + + +def test_lambda_parameter_not_dependency(): + """Test that lambda parameters aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "handler = lambda x, y: x + y + z" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "x" not in deps, "Lambda parameter 'x' should not be a dependency" + assert "y" not in deps, "Lambda parameter 'y' should not be a dependency" + assert "z" in deps, "Free variable 'z' should be a dependency" + + +def test_lambda_star_args_not_dependency(): + """Test that lambda *args and **kwargs aren't external dependencies.""" + from codeflash.code_utils.code_extractor import get_statement_dependencies + + code = "handler = lambda *args, **kwargs: process(args, kwargs, config)" + module = cst.parse_module(code) + deps = get_statement_dependencies(module.body[0]) + + assert "args" not in deps, "Lambda *args should not be a dependency" + assert "kwargs" not in deps, "Lambda **kwargs should not be a dependency" + assert "process" in deps, "Free variable 'process' should be a dependency" + assert "config" in deps, "Free variable 'config' should be a dependency" + + +def test_circular_dependency_warning(caplog): + """Test that circular dependencies produce a warning.""" + import logging + + from codeflash.code_utils.code_extractor import _sort_statements_by_dependencies + + code = """ +x = y + 1 +y = x + 1 +""" + module = cst.parse_module(code) + + with caplog.at_level(logging.WARNING): + result = _sort_statements_by_dependencies(list(module.body)) + + # Should emit a warning about circular dependency + assert any("Circular dependency detected" in record.message for record in caplog.records), ( + "Should log a warning about circular dependencies" + ) + # Should return original order when cycle is detected + assert len(result) == 2, "Should return all statements" + + +def test_add_global_assignments_with_match_statement(): + """Test that match statements in optimized code are handled correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +match config: + case {"type": t}: + result = t + case _: + result = "default" + +use_result = result + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code contains the match statement + assert "match config" in result, "Match statement should be in result" + assert "use_result = result" in result, "Variable using match result should be in result" + + # The code should be syntactically valid (but may have runtime NameErrors + # due to undefined 'config') + compiled = compile(result, "", "exec") + assert compiled is not None + + +def test_comprehension_in_global_assignment(): + """Test that global assignments with comprehensions work correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +items = [1, 2, 3, 4, 5] +doubled = [x * 2 for x in items] +filtered = [x for x in doubled if x > 4] + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + assert namespace["doubled"] == [2, 4, 6, 8, 10] + assert namespace["filtered"] == [6, 8, 10] + except NameError as e: + msg = f"Comprehension test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e + + +def test_lambda_in_global_assignment(): + """Test that global assignments with lambdas work correctly.""" + from codeflash.code_utils.code_extractor import add_global_assignments + + original_code = """def foo(): + pass +""" + + optimized_code = """ +multiplier = 2 +scale = lambda x: x * multiplier + +def foo(): + pass +""" + + result = add_global_assignments(optimized_code, original_code) + + # Verify the code is valid and executes without NameError + try: + compiled = compile(result, "", "exec") + namespace = {} + exec(compiled, namespace) + assert namespace["scale"](5) == 10 + except NameError as e: + msg = f"Lambda test failed with NameError: {e}\n\nGenerated code:\n{result}" + raise AssertionError(msg) from e From 332eb559cd36cdd9933ad693fd62ca4c875fd049 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 25 Jan 2026 05:14:36 -0500 Subject: [PATCH 14/14] fix: correct variable name in add_global_assignments --- codeflash/code_utils/code_extractor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 84b9cc719..5f296a8a3 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1255,11 +1255,11 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: # Insert new function definitions if any if new_function_defs: - last_import_line = find_last_import_line(mod_dst_code) + last_import_line = find_last_import_line(dst_module_code) transformer = FunctionDefInserter(new_function_defs, last_import_line) modified_module = original_module.visit(transformer) - mod_dst_code = modified_module.code - original_module = cst.parse_module(mod_dst_code) + dst_module_code = modified_module.code + original_module = cst.parse_module(dst_module_code) # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file