diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py
index 6ddfe763a..fe4e5cf89 100644
--- a/codeflash/code_utils/code_extractor.py
+++ b/codeflash/code_utils/code_extractor.py
@@ -3,6 +3,7 @@
 import ast
 import time
 from dataclasses import dataclass
+from functools import lru_cache
 from importlib.util import find_spec
 from itertools import chain
 from pathlib import Path
@@ -513,7 +514,7 @@ def visit_Try(self, node: cst.Try) -> None:
 
 def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
     """Extract global statements from source code."""
-    module = cst.parse_module(source_code)
+    module = _parse_module_cached(source_code)
     collector = GlobalStatementCollector()
     module.visit(collector)
     return module, collector.global_statements
@@ -1581,3 +1582,13 @@ def get_opt_review_metrics(
     end_time = time.perf_counter()
     logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
     return calling_fns_details
+
+
+@lru_cache(maxsize=128)
+def _parse_module_cached(source_code: str) -> cst.Module:
+    """Parse ``source_code`` with libcst, reusing the result for repeated input.
+
+    Up to 128 parsed modules are kept, keyed on the exact source text, so
+    equal input returns the same ``cst.Module`` object instead of re-parsing.
+    """
+    return cst.parse_module(source_code)