Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,75 @@ def optimize_python_code( # noqa: D417
console.rule()
return []

def augmented_optimize(  # noqa: D417
    self,
    source_code: str,
    system_prompt: str,
    user_prompt: str,
    trace_id: str,
    dependency_code: str | None = None,
    n_candidates: int = 3,
) -> list[OptimizedCandidate]:
    """Request optimization candidates generated from caller-supplied prompts.

    Posts the code and both prompts to the ``/augmented-optimize`` AI-service
    endpoint and converts the returned optimizations into candidates.

    Parameters
    ----------
    - source_code (str): The python code to optimize (markdown format).
    - system_prompt (str): Custom system prompt for the LLM.
    - user_prompt (str): Custom user prompt for the LLM.
    - trace_id (str): Trace id of optimization run.
    - dependency_code (str | None): Optional dependency code context.
    - n_candidates (int): Number of candidates to generate; values above 3 are sent as 3.

    Returns
    -------
    - list[OptimizedCandidate]: The valid candidates from the response, or an
      empty list when the request raises or the status code is not 200.

    """
    logger.info("Generating augmented optimization candidates…")
    console.rule()
    started_at = time.perf_counter()
    repo_owner, repo_name = safe_get_repo_owner_and_name()

    # The request never asks for more than three candidates.
    capped_candidates = min(n_candidates, 3)
    payload = {
        "source_code": source_code,
        "dependency_code": dependency_code,
        "trace_id": trace_id,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "n_candidates": capped_candidates,
        "python_version": platform.python_version(),
        "codeflash_version": codeflash_version,
        "current_username": get_last_commit_author_if_pr_exists(None),
        "repo_owner": repo_owner,
        "repo_name": repo_name,
    }
    logger.debug(f"Sending augmented optimize request: trace_id={trace_id}, n_candidates={capped_candidates}")

    try:
        response = self.make_ai_service_request("/augmented-optimize", payload=payload, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating augmented optimization candidates: {e}")
        ph("cli-augmented-optimize-error-caught", {"error": str(e)})
        console.rule()
        return []

    # Failure path first: report the server's error (or the raw body when it is not JSON) and bail out.
    if response.status_code != 200:
        try:
            error = response.json()["error"]
        except Exception:
            error = response.text
        logger.error(f"Error generating augmented optimization candidates: {response.status_code} - {error}")
        ph("cli-augmented-optimize-error-response", {"response_status_code": response.status_code, "error": error})
        console.rule()
        return []

    optimizations_json = response.json()["optimizations"]
    elapsed = time.perf_counter() - started_at
    logger.debug(f"!lsp|Generating augmented optimizations took {elapsed:.2f} seconds.")
    logger.info(f"!lsp|Received {len(optimizations_json)} augmented optimization candidates.")
    console.rule()
    return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.AUGMENTED)

def get_jit_rewritten_code( # noqa: D417
self, source_code: str, trace_id: str
) -> list[OptimizedCandidate]:
Expand Down
42 changes: 42 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.extension import install_vscode_extension
Expand Down Expand Up @@ -57,6 +58,22 @@ def parse_args() -> Namespace:
help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
)

# create-pr subcommand for creating PRs from augmented optimization results
create_pr_parser = subparsers.add_parser("create-pr", help="Create a PR from previously applied optimizations")
create_pr_parser.set_defaults(func=cmd_create_pr)
create_pr_parser.add_argument(
"--results-file",
type=str,
default="codeflash_phase1_results.json",
help="Path to augmented output JSON file (default: codeflash_phase1_results.json)",
)
create_pr_parser.add_argument(
"--function", type=str, help="Function name (required if multiple functions in results)"
)
create_pr_parser.add_argument(
"--git-remote", type=str, default="origin", help="Git remote to use for PR creation (default: origin)"
)

parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
parser.add_argument(
Expand Down Expand Up @@ -117,6 +134,22 @@ def parse_args() -> Namespace:
parser.add_argument(
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
)
parser.add_argument(
"--augmented",
action="store_true",
help="Enable augmented optimization mode for two-phase optimization workflow",
)
parser.add_argument(
"--augmented-prompt-file",
type=str,
help="Path to YAML file with custom system_prompt and user_prompt for Phase 2",
)
parser.add_argument(
"--augmented-output",
type=str,
default="codeflash_phase1_results.json",
help="Path to write Phase 1 results JSON (default: codeflash_phase1_results.json)",
)

args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
Expand Down Expand Up @@ -175,6 +208,15 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
"Async function optimization is now enabled by default."
)

if args.augmented_prompt_file and not args.augmented:
exit_with_message("--augmented-prompt-file requires --augmented flag", error_on_exit=True)

if args.augmented_prompt_file:
prompt_file = Path(args.augmented_prompt_file)
if not prompt_file.exists():
exit_with_message(f"Augmented prompt file {args.augmented_prompt_file} does not exist", error_on_exit=True)
args.augmented_prompt_file = prompt_file.resolve()

return args


Expand Down
147 changes: 147 additions & 0 deletions codeflash/cli_cmds/cmd_create_pr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.models.models import Phase1Output, TestResults
from codeflash.result.create_pr import check_create_pr
from codeflash.result.explanation import Explanation

if TYPE_CHECKING:
from argparse import Namespace

# Pattern to extract file path from markdown code block: ```python:path/to/file.py
MARKDOWN_FILE_PATH_PATTERN = re.compile(r"```python:([^\n]+)")


def extract_file_path_from_markdown(markdown_code: str) -> str | None:
"""Extract the file path from markdown code block format.

Format: ```python:path/to/file.py
"""
match = MARKDOWN_FILE_PATH_PATTERN.search(markdown_code)
if match:
return match.group(1).strip()
return None


def extract_code_from_markdown(markdown_code: str) -> str:
    r"""Extract the code content from a markdown code block.

    Removes the ```python:path\n ... ``` wrapper: an opening fence (with an
    optional ``:path`` annotation) at the very start of the text, and a
    closing fence at the very end (a single trailing newline after the
    closing fence is tolerated and kept). Text without fences is returned
    unchanged.

    An empty block such as "```python:a.py\n```" yields an empty string.
    """
    # Remove the opening markdown fence with its optional path annotation.
    code = re.sub(r"^```python(?::[^\n]*)?\n", "", markdown_code)
    # Remove the closing fence. It normally follows a newline, but when the block is
    # empty the opening-fence removal has already consumed that newline and the
    # closing fence is all that remains, so also accept it at the start of the text.
    return re.sub(r"(?:^|\n)```$", "", code)


def create_pr(args: Namespace) -> None:
    """Create a PR from previously applied optimizations.

    Reads the results JSON at ``args.results_file``, selects one function
    result (the only one, or the one named by ``args.function``), and hands
    the recorded original code plus the file's current on-disk content to
    ``check_create_pr``. The results file is deleted once ``check_create_pr``
    returns.

    Every validation failure (missing, unreadable or malformed results file,
    no matching function, no best candidate, unknown or missing source file)
    is reported through ``exit_with_message``.

    Parameters
    ----------
    - args (Namespace): Parsed CLI arguments. ``results_file`` and
      ``function`` are read directly; ``git_remote`` is read if present.

    """
    results_file = Path(args.results_file)

    if not results_file.exists():
        exit_with_message(f"Results file not found: {results_file}", error_on_exit=True)

    # Load and validate inside one guarded step, so that an unreadable or malformed
    # JSON file is reported through exit_with_message like a schema mismatch is,
    # instead of surfacing as a raw traceback.
    try:
        with results_file.open(encoding="utf-8") as f:
            data = json.load(f)
        output = Phase1Output.model_validate(data)
    except Exception as e:
        exit_with_message(f"Failed to parse results file: {e}", error_on_exit=True)

    # Find the function result
    if len(output.functions) == 0:
        exit_with_message("No functions in results file", error_on_exit=True)

    if len(output.functions) > 1 and not args.function:
        func_names = [f.function_name for f in output.functions]
        exit_with_message(
            f"Multiple functions in results. Specify one with --function: {func_names}", error_on_exit=True
        )

    # Default to the first (only) entry; an explicit --function overrides it.
    func_result = output.functions[0]
    if args.function:
        func_result = next((f for f in output.functions if f.function_name == args.function), None)
        if not func_result:
            exit_with_message(f"Function {args.function} not found in results", error_on_exit=True)
        assert func_result is not None  # for type checker - exit_with_message doesn't return

    if not func_result.best_candidate_id:
        exit_with_message("No successful optimization found in results", error_on_exit=True)

    # Get file path - prefer explicit field, fall back to extracting from markdown
    file_path_str = func_result.file_path
    if not file_path_str:
        file_path_str = extract_file_path_from_markdown(func_result.original_source_code)

    if not file_path_str:
        exit_with_message(
            "Could not determine file path from results. Results file may be from an older version of codeflash.",
            error_on_exit=True,
        )
    assert file_path_str is not None  # for type checker - exit_with_message doesn't return

    file_path = Path(file_path_str)
    if not file_path.exists():
        exit_with_message(f"Source file not found: {file_path}", error_on_exit=True)

    # Read current (optimized) file content
    current_content = file_path.read_text(encoding="utf-8")

    # Extract original code (strip markdown)
    original_code = extract_code_from_markdown(func_result.original_source_code)

    # Get the best candidate's explanation
    best_explanation = func_result.best_candidate_explanation
    if not best_explanation:
        # Fall back to the candidate's explanation if the final explanation wasn't captured
        best_candidate = next(
            (c for c in func_result.candidates if c.optimization_id == func_result.best_candidate_id), None
        )
        best_explanation = best_candidate.explanation if best_candidate else "Optimization applied"

    # Build Explanation object for PR creation. The test-result fields are filled with
    # empty TestResults() instances, and missing runtimes fall back to the original
    # runtime or 0.
    explanation = Explanation(
        raw_explanation_message=best_explanation,
        winning_behavior_test_results=TestResults(),
        winning_benchmarking_test_results=TestResults(),
        original_runtime_ns=func_result.original_runtime_ns or 0,
        best_runtime_ns=func_result.best_runtime_ns or func_result.original_runtime_ns or 0,
        function_name=func_result.function_name,
        file_path=file_path,
    )

    logger.info(f"Creating PR for optimized function: {func_result.function_name}")
    logger.info(f"File: {file_path}")
    if func_result.best_speedup_ratio:
        logger.info(f"Speedup: {func_result.best_speedup_ratio * 100:.1f}%")

    # Call existing PR creation
    check_create_pr(
        original_code={file_path: original_code},
        new_code={file_path: current_content},
        explanation=explanation,
        existing_tests_source=func_result.existing_tests_source or "",
        generated_original_test_source="",
        function_trace_id=func_result.trace_id,
        coverage_message="",
        replay_tests=func_result.replay_tests_source or "",
        concolic_tests=func_result.concolic_tests_source or "",
        optimization_review="",
        root_dir=git_root_dir(),
        git_remote=getattr(args, "git_remote", None),
        precomputed_test_report=func_result.test_report,
        precomputed_loop_count=func_result.loop_count,
    )

    # Cleanup results file once check_create_pr has returned
    results_file.unlink()
    logger.info(f"Cleaned up results file: {results_file}")
Loading
Loading