diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 4ef96b308..96e467072 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -224,7 +224,7 @@ def get_jit_rewritten_code( # noqa: D417 logger.info("!lsp|Rewriting as a JIT function…") console.rule() try: - response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) + response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=self.timeout) except requests.exceptions.RequestException as e: logger.exception(f"Error generating jit rewritten candidate: {e}") ph("cli-jit-rewrite-error-caught", {"error": str(e)}) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index df0ae57ce..e135cd022 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -84,6 +84,9 @@ def parse_args() -> Namespace: parser.add_argument( "--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization." ) + parser.add_argument( + "--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code." + ) parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review") parser.add_argument( "--verify-setup", diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 697061b9c..256b65b7a 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -610,7 +610,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: ): console.rule() new_code_context = code_context - if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) + if ( + self.is_numerical_code and not self.args.no_jit_opts + ): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( code_context.read_writable_code.markdown, self.function_trace_id ) @@ -639,7 +641,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, - is_numerical_code=self.is_numerical_code, + is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -1158,7 +1160,7 @@ def determine_best_candidate( ) if self.experiment_id else None, - is_numerical_code=self.is_numerical_code, + is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, ) processor = CandidateProcessor(