Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
722d053
decorators & annotations
KRRT7 Jan 23, 2026
4e31032
fix: recursively extract base classes for imported dataclasses
KRRT7 Jan 23, 2026
ebf7703
test: add tests for base class extraction in imported dataclasses
KRRT7 Jan 23, 2026
9cefc03
fix: handle instance state capture for classes without __dict__
KRRT7 Jan 23, 2026
412779d
fix: handle repr failures for Mock objects in test result comparison
KRRT7 Jan 23, 2026
9f929c2
fix: handle annotated assignments in GlobalAssignmentCollector
KRRT7 Jan 23, 2026
6009b83
fix: handle module-level function definitions in add_global_assignments
KRRT7 Jan 23, 2026
dbc88ad
feat: extract __init__ from external library base classes for test co…
KRRT7 Jan 24, 2026
323e52e
fix: use PicklePatcher for instance state to handle unpicklable async…
KRRT7 Jan 24, 2026
34de676
fix: track attribute access base names as dependencies in DependencyC…
KRRT7 Jan 24, 2026
6b3b10e
fix: include dependency classes in read-writable optimization context
KRRT7 Jan 24, 2026
1bb9d14
Merge branch 'main' into skyvern-grace
KRRT7 Jan 24, 2026
abfa640
fix: insert global assignments after class definitions to prevent Nam…
KRRT7 Jan 24, 2026
50fba09
fix: insert global statements after function definitions to prevent N…
KRRT7 Jan 24, 2026
257c5f2
test: update test expectations for global assignment placement changes
KRRT7 Jan 24, 2026
7b33e8b
refactor: smarter placement of global assignments based on dependencies
KRRT7 Jan 24, 2026
48b5ff3
refactor: merge duplicate CST pruning functions into single parameter…
KRRT7 Jan 24, 2026
9c9593a
refactor: merge duplicate CST pruning functions into single parameter…
KRRT7 Jan 24, 2026
1b92d11
Update mypy.yml
KRRT7 Jan 24, 2026
aba3e79
Merge branch 'skyvern-grace' of https://github.com/codeflash-ai/codef…
KRRT7 Jan 24, 2026
571b2ba
skip for 3.9
KRRT7 Jan 24, 2026
47b5235
refactor: remove unused code and simplify return pattern in unused_de…
KRRT7 Jan 24, 2026
69740f0
refactor: simplify code_context_extractor by extracting helper and re…
KRRT7 Jan 24, 2026
65ff392
add tests
KRRT7 Jan 24, 2026
b54dfca
Optimize detect_unused_helper_functions
codeflash-ai[bot] Jan 24, 2026
7e3d6ec
Merge pull request #1169 from codeflash-ai/codeflash/optimize-pr1166-…
KRRT7 Jan 25, 2026
214d891
Merge branch 'main' into skyvern-grace
KRRT7 Jan 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ jobs:

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.5.30"

- name: sync uv
run: |
uv venv --seed
uv sync


Expand Down
344 changes: 268 additions & 76 deletions codeflash/code_utils/code_extractor.py

Large diffs are not rendered by default.

708 changes: 416 additions & 292 deletions codeflash/context/code_context_extractor.py

Large diffs are not rendered by default.

110 changes: 55 additions & 55 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,18 @@ def visit_Name(self, node: cst.Name) -> None:
return

if name in self.definitions and name != self.current_top_level_name:
# skip if we are referencing a class attribute and not a top-level definition
# Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x')
# We only want to track the base/value of attribute access, not the attribute name itself
if self.class_depth > 0:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if parent is not None and isinstance(parent, cst.Attribute):
return
# Check if this Name is the .attr (property name), not the .value (base)
# If it's the .attr, skip it - attribute names aren't references to definitions
if parent.attr is node:
return
# If it's the .value (base), only skip if it's self/cls
if name in ("self", "cls"):
return
self.definitions[self.current_top_level_name].dependencies.add(name)


Expand Down Expand Up @@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
return code


def print_definitions(definitions: dict[str, UsageInfo]) -> None:
    """Debugging aid: write a summary of every collected definition to stdout.

    The complex node object held by each entry is left out on purpose; only
    the name, the "used by qualified function" flag and the dependency names
    are printed. Entries are listed in sorted-name order, and dependency
    names are sorted and comma-separated ('None' when there are none).

    Args:
        definitions: Mapping of definition name to its usage record.

    """
    print(f"Found {len(definitions)} definitions:")
    for def_name in sorted(definitions):
        usage = definitions[def_name]
        print(f" - Name: {def_name}")
        print(f" Used by qualified function: {usage.used_by_qualified_function}")
        # Empty dependency collections are rendered as the literal text 'None'.
        deps_text = ", ".join(sorted(usage.dependencies)) if usage.dependencies else "None"
        print(f" Dependencies: {deps_text}")
        print()


def revert_unused_helper_functions(
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
Expand Down Expand Up @@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code(
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
helpers_by_file[module_name].append(helper)

# Optimize attribute lookups and method binding outside the loop
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get

for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
file_entry = helpers_by_file_and_func.get(module_name)
if file_entry:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
helpers = file_entry.get(original_name, None)
helpers = file_entry.get(original_name)
if helpers:
imported_set = imported_names_map[imported_name]
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)
imported_set.add(helper.qualified_name)
imported_set.add(helper.fully_qualified_name)

elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
helpers = helpers_by_file.get(module_name)
if helpers:
imported_set = imported_names_map[f"{imported_name}.{{func}}"]
for helper in helpers:
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
full_call_set = imported_names_map[full_call]
full_call_set.add(helper.qualified_name)
full_call_set.add(helper.fully_qualified_name)

return dict(imported_names_map)

Expand Down Expand Up @@ -753,27 +747,31 @@ def detect_unused_helper_functions(
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
mapped_names = imported_names_map.get(called_name)
if mapped_names:
called_function_names.update(mapped_names)
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
attr_name = node.func.attr
value_id = node.func.value.id
if value_id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
called_function_names.add(attr_name)
# For class methods, also add the qualified name
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
called_function_names.add(f"{class_name}.{attr_name}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
full_call = f"{value_id}.{attr_name}"
called_function_names.add(full_call)
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
mapped_names = imported_names_map.get(full_call)
if mapped_names:
called_function_names.update(mapped_names)
# Handle nested attribute access like obj.attr.method()
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
Expand All @@ -783,36 +781,38 @@ def detect_unused_helper_functions(

# Find helper functions that are no longer called
unused_helpers = []
entrypoint_file_path = function_to_optimize.file_path
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# Check membership efficiently - exit early on first match
if (
helper_qualified_name in called_function_names
or helper_simple_name in called_function_names
or helper_fully_qualified_name in called_function_names
):
is_called = True
# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
elif helper_function.file_path != entrypoint_file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))
module_call = f"{module_name}.{helper_simple_name}"
is_called = module_call in called_function_names
else:
is_called = False

if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")

ret_val = unused_helpers

except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
ret_val = []
return ret_val
return []
else:
return unused_helpers
25 changes: 19 additions & 6 deletions codeflash/verification/codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import dill as pickle
from dill import PicklingWarning

from codeflash.picklepatch.pickle_patcher import PicklePatcher

warnings.filterwarnings("ignore", category=PicklingWarning)


Expand Down Expand Up @@ -148,18 +150,29 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003
print(f"!######{test_stdout_tag}######!")

# Capture instance state after initialization
if hasattr(args[0], "__dict__"):
instance_state = args[
0
].__dict__ # self is always the first argument, this is ensured during instrumentation
# self is always the first argument, this is ensured during instrumentation
instance = args[0]
if hasattr(instance, "__dict__"):
instance_state = instance.__dict__
elif hasattr(instance, "__slots__"):
# For classes using __slots__, capture slot values
instance_state = {
slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot)
}
else:
raise ValueError("Instance state could not be captured.")
# For C extension types or other special classes (e.g., Playwright's Page),
# capture all non-private, non-callable attributes
instance_state = {
attr: getattr(instance, attr)
for attr in dir(instance)
if not attr.startswith("_") and not callable(getattr(instance, attr, None))
}
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
)

# Write to sqlite
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state)
pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state)
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
Expand Down
12 changes: 10 additions & 2 deletions codeflash/verification/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
test_diff_repr = reprlib_repr.repr


def safe_repr(obj: object) -> str:
    """Return ``repr(obj)``, or a placeholder string when computing it fails.

    Test results can contain arbitrary user objects (for example Mock objects
    with corrupted state) whose ``__repr__`` raises. Building a diff message
    must never fail because of that, so any ``Exception`` raised by ``repr``
    is converted into a descriptive placeholder instead of propagating.

    Args:
        obj: Any object.

    Returns:
        ``repr(obj)`` on success, otherwise
        ``"<repr failed: <ExceptionType>: <message>>"``.

    """
    try:
        return repr(obj)
    except Exception as e:  # noqa: BLE001 - a user-defined __repr__ may raise anything
        try:
            detail = f"{type(e).__name__}: {e}"
        except Exception:  # noqa: BLE001 - the exception's own __str__ can be broken too
            detail = type(e).__name__
        return f"<repr failed: {detail}>"


def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
# This is meant to be only called with test results for the first loop index
if len(original_results) == 0 or len(candidate_results) == 0:
Expand Down Expand Up @@ -77,8 +85,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
test_diffs.append(
TestDiff(
scope=TestDiffScope.RETURN_VALUE,
original_value=test_diff_repr(repr(original_test_result.return_value)),
candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
original_value=test_diff_repr(safe_repr(original_test_result.return_value)),
candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)),
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
candidate_pytest_error=cdd_pytest_error,
original_pass=original_test_result.did_pass,
Expand Down
Loading
Loading