diff --git a/optillm/cepo/cepo.py b/optillm/cepo/cepo.py index c9a7c2cf..be687244 100644 --- a/optillm/cepo/cepo.py +++ b/optillm/cepo/cepo.py @@ -1,4 +1,5 @@ import re +import logging import yaml import json import optillm @@ -14,6 +15,7 @@ from openai import BadRequestError as OpenAIBadRequestError from openai import InternalServerError as OpenAIInternalServerError +logger = logging.getLogger(__name__) @dataclass class CepoConfig: @@ -30,10 +32,11 @@ class CepoConfig: planning_temperature_step4: float # temperature for generator in step 4 of planning stage planning_max_tokens_step1: int # maximum number of tokens in step 1 of planning stage planning_max_tokens_step2: int # maximum number of tokens in step 2 of planning stage - planning_max_tokens_direct_resp: float # maximum number of tokens after step 2 if planning fails and answer directly + planning_max_tokens_direct_resp: int # maximum number of tokens after step 2 if planning fails and answer directly planning_max_tokens_step3: int # maximum number of tokens in step 3 of planning stage planning_max_tokens_step4: int # maximum number of tokens in step 4 of planning stage use_plan_diversity: bool # whether to use plan diversity + use_reasoning: bool # whether this model supports setting the reasoning effort parameter [True] use_reasoning_fallback: bool # whether to fallback to lower levels of reasoning when higher level fails num_of_retries: int # number of retries if llm call fails, 0 for no retries rating_model: Optional[str] = None # model to be used for rating @@ -240,10 +243,12 @@ def llm_call( return response_text, finish_reason, completion_tokens except (OpenAIBadRequestError, OpenAIInternalServerError) as e: + if logger.getEffectiveLevel() == logging.DEBUG: + logger.exception("Bad OpenAI Request") # Retry on 400 or 500 if attempt < retries - 1: sleep_time = 0.2 * (attempt + 1) - print(f"Got {e.__class__.__name__}, retrying in {sleep_time:.1f}s...") + logger.debug(f"Got {e.__class__.__name__}, retrying in {sleep_time:.1f}s...") time.sleep(sleep_time) continue raise @@ -302,7 +307,11 @@ def llm_call_reason_effort_fallback( for effort in reasoning_effort_levels: try: # Try with the current reasoning effort level - provider_request["reasoning_effort"] = effort + if cepo_config.use_reasoning: + provider_request["reasoning_effort"] = effort + else: + # Ensure reasoning_effort isn't set on the provider request + provider_request.pop("reasoning_effort", None) response, finish_reason, completion_tokens = llm_call( client=client, provider_request=provider_request, @@ -310,10 +319,17 @@ def llm_call_reason_effort_fallback( ) if response is not None and finish_reason != "length": return response, finish_reason, completion_tokens - print(f"Reasoning fallback from {effort}, to lower one") + logger.debug(f"Reasoning fallback from {effort}, to lower one") + except OpenAIBadRequestError as bre: + # Check to see if we attempted to use a reasoning effort not supported by the model + if len(reasoning_effort_levels) == 1 and bre.message.startswith("Error code: 400 - {'error': {'message': 'think value"): + logger.info(f"The think level {effort} was not supported by the model; Disabling thinking") + cepo_config.use_reasoning = False except (OpenAIBadRequestError, OpenAIInternalServerError) as e: # After 2 retries at this reasoning effort level it failed with error 400/500, lower level - print("400/500 persisted after retries at reasoning effort", effort, "→ degrading") + logger.debug(f"400/500 persisted after retries at reasoning effort {effort}; 
degrading effort") + if logger.getEffectiveLevel() == logging.DEBUG: + logger.exception("Bad OpenAI Request; Fallback") continue return None, "error", 0 @@ -403,7 +419,7 @@ def generate_single_plan(i): "temperature": cepo_config.planning_temperature_step1, "top_p": 1.0 } - + response, finish_reason, completion_tokens = llm_call_reason_effort_fallback( client=client, provider_request=provider_request, @@ -451,24 +467,24 @@ def generate_single_plan(i): provider_request = { "model": model, "messages": messages, - "max_tokens": cepo_config.planning_max_tokens_step2_direct, - "temperature":cepo_config.planning_temperature_step2_direct, + "max_tokens": cepo_config.planning_max_tokens_direct_resp, + "temperature":cepo_config.planning_temperature_direct_resp, "top_p": 0.95, - "reasoning_effort_levels": ["high", "medium", "low"] } - response, finish_reason, completion_tokens = llm_call_reason_effort_fallback( + response, finish_reason, tokens_used = llm_call_reason_effort_fallback( client=client, provider_request=provider_request, + reasoning_effort_levels=["high", "medium", "low"], cepo_config=cepo_config ) - local_completion_tokens += completion_tokens + completion_tokens += tokens_used # Log provider call if conversation logging is enabled if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: response_dict = response.model_dump() if hasattr(response, 'model_dump') else response - optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) - + optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) + if response is None or finish_reason == "length": print("Direct answer failed, empty response or length") response = "" @@ -490,7 +506,7 @@ def generate_single_plan(i): user_content = f"Previous responses to review:\n\n{plans_message}\n\n{content}" messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}] - + provider_request = { "model": model, "messages": messages, @@ -498,14 +514,14 @@ def generate_single_plan(i): "temperature": cepo_config.planning_temperature_step1, "top_p": 1.0 } - - response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback( + + response, finish_reason, tokens_used = llm_call_reason_effort_fallback( client=client, provider_request=provider_request, reasoning_effort_levels=["high", "medium"], cepo_config=cepo_config ) - completion_tokens += completion_tokens_ + completion_tokens += tokens_used # Log provider call if conversation logging is enabled if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: @@ -516,7 +532,7 @@ def generate_single_plan(i): print("Step 3 failed and only taking plans[0]") final_solution = plans[0] else: - completion_tokens += completion_tokens + #completion_tokens += completion_tokens # TODO: This seems like a bug final_solution = response messages.append({"role": "assistant", "content": final_solution}) @@ -524,7 +540,7 @@ def generate_single_plan(i): if cepo_config.planning_max_tokens_step4 != 0: content = f"Use your final solution from above to correctly answer the question. 
Here is the question:\n{task} /think" messages = [ - {"role": "system", "content": system_prompt}, + {"role": "system", "content": system_prompt}, {"role": "user", "content": f"Here's my final solution: {final_solution}\n\nNow {content}"} ] @@ -536,13 +552,13 @@ def generate_single_plan(i): "top_p": 1.0 } - response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback( + response, finish_reason, tokens_used = llm_call_reason_effort_fallback( client=client, provider_request=provider_request, reasoning_effort_levels=["high", "medium"], cepo_config=cepo_config ) - completion_tokens += completion_tokens_ + completion_tokens += tokens_used # Log provider call if conversation logging is enabled if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: @@ -578,7 +594,7 @@ def generate_approaches(system_prompt: str, initial_query: str, num_approach: in f' ...\n'\ f'}}' messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}] - + retries = 0 while retries < max_retry: try: @@ -590,15 +606,15 @@ def generate_approaches(system_prompt: str, initial_query: str, num_approach: in "temperature": cepo_config.planning_temperature_step0, "stream": False, } - + response = client.chat.completions.create(**provider_request) - + # Log provider call if conversation logging is enabled if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: response_dict = response.model_dump() if hasattr(response, 'model_dump') else response optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) completion_tokens += response.usage.completion_tokens - completion = response.choices[0].message.content + completion = response.choices[0].message.content # Try to parse the completion as JSON, escape latex math symbols cleaned_completion = completion.replace('\\', '\\\\').replace('json','').replace("```", "") @@ -610,13 +626,13 @@ def generate_approaches(system_prompt: str, initial_query: str, num_approach: in # If there's an error, print a message and regenerate the content print(e) print(f"Parsing Error when generating diverse approaches, retrying... 
({retries + 1}/{max_retry})") - + retries += 1 if retries == max_retry: print("Max retry attempts reached, returning empty list.") return [], 0 # Default approach - + return approaches, completion_tokens @@ -654,12 +670,12 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any, ) cb_log["approaches"] = approaches completion_tokens += approach_completion_tokens - if cepo_config.print_output: - print(f"\nCePO: Plan diversity approaches ({cepo_config.bestofn_n}):\n{approaches}\n") + if cepo_config.print_output or logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"\nCePO: Plan diversity approaches ({cepo_config.bestofn_n}):\n{approaches}\n") def run_single_completion(i): - if cepo_config.print_output: - print(f"\nCePO: Generating completion {i + 1} out of {cepo_config.bestofn_n} \n") + if cepo_config.print_output or logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"\nCePO: Generating completion {i + 1} out of {cepo_config.bestofn_n} \n") approach = approaches[i] if approaches else None response_i, completion_tokens_i, cb_log_i = generate_completion(system_prompt, initial_query, client, model, cepo_config, approach, request_id) return i, response_i, completion_tokens_i, cb_log_i @@ -674,9 +690,9 @@ def run_single_completion(i): cb_log[f"completion_{i}_response"] = response_i cb_log[f"completion_{i}_log"] = cb_log_i cb_log[f"completion_{i}_completion_tokens"] = tokens_i - - if cepo_config.print_output: - print(f"\nCePO: All Answers generated!") + + if cepo_config.print_output or logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"\nCePO: All Answers generated!") completions = [c if isinstance(c, str) else "" for c in completions] return completions, completion_tokens, cb_log @@ -685,7 +701,7 @@ def run_single_completion(i): def rate_completions_absolute(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict, request_id: str = None) -> tuple[str, int, dict]: """ Rates completions for the Best of N step of CePO. Each completion is rated on a scale of 1 to 10 individually. - + Parameters: system_prompt (str): The system prompt to guide the model. initial_query (str): The task or question to be addressed. 
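
Note: the retry loop above re-parses the model output as JSON after escaping backslashes and stripping code fences. Below is a minimal, self-contained sketch of that parse-and-retry pattern, assuming an OpenAI-compatible `client.chat.completions.create` call; the `approach_{i}` key layout and the helper names (`parse_approaches`, `generate_with_retries`) are illustrative and not taken from the diff.

```python
import json
import logging

logger = logging.getLogger(__name__)

def parse_approaches(completion: str, expected: int) -> list[str]:
    """Parse the JSON object of approaches out of a raw model completion."""
    # Escape LaTeX-style backslashes and strip backticks/"json" markers before
    # parsing, mirroring the cleaning step in generate_approaches.
    cleaned = completion.replace('\\', '\\\\').replace('json', '').replace('`', '')
    parsed = json.loads(cleaned)
    # Hypothetical key layout -- the real prompt defines its own JSON schema.
    return [parsed[f"approach_{i + 1}"] for i in range(expected)]

def generate_with_retries(client, provider_request: dict, expected: int, max_retry: int = 3):
    """Call the model and retry whenever the JSON payload fails to parse."""
    for attempt in range(max_retry):
        response = client.chat.completions.create(**provider_request)
        completion = response.choices[0].message.content
        try:
            return parse_approaches(completion, expected), response.usage.completion_tokens
        except (json.JSONDecodeError, KeyError) as e:
            logger.debug(f"Parse error on attempt {attempt + 1}/{max_retry}: {e}")
    logger.warning("Max retry attempts reached, returning empty list.")
    return [], 0
```
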
@@ -716,13 +732,13 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An "\"Rating: [[rating]]\", for example: \"Rating: [[0]]\"" rating_format_instruction = "\n\nRate the above response beginning with the detailed explanation followed by a rating of 0 or 1 by strictly following this format: \"Explanation: \n\nRating: [[rating]]\"" - + ratings = [] for i, completion in enumerate(completions): # Create a fresh conversation with proper role alternation for each completion system_content = f"USER QUESTION: {initial_query}\n\nRESPONSE: {completion}" rating_messages = [ - {"role": "system", "content": system_prompt + "\n\n" + rating_prompt}, + {"role": "system", "content": system_prompt + "\n\n" + rating_prompt}, {"role": "user", "content": system_content + rating_format_instruction} ] @@ -734,7 +750,7 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An "temperature": cepo_config.bestofn_temperature, "top_p": 1.0 } - + rating_response = client.chat.completions.create(**provider_request) rating_response, _, completion_tokens_ = llm_call_reason_effort_fallback( client=client, @@ -762,7 +778,7 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An ratings.append(float(rating_response)) except ValueError: ratings.append(-1) - + best_index = ratings.index(max(ratings)) cb_log["ratings"] = ratings cb_log["best_index"] = best_index @@ -816,7 +832,7 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An # Create a fresh conversation for each comparison with proper system→user structure rating_messages = [ - {"role": "system", "content": system_prompt + "\n\n" + rating_prompt}, + {"role": "system", "content": system_prompt + "\n\n" + rating_prompt}, {"role": "user", "content": comparison_content} ] @@ -828,15 +844,15 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An "temperature": cepo_config.bestofn_temperature } rating_response = client.chat.completions.create(**provider_request) - + # Log provider call if conversation logging is enabled if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) - + completion_tokens += rating_response.usage.completion_tokens rating_response = rating_response.choices[0].message.content.strip() - + cb_log[f"rating_response_for_pair_{pair[0]}_{pair[1]}"] = rating_response if cepo_config.print_output: print(f"\nCePO: Rating response for pair {pair}: {rating_response}") @@ -852,7 +868,7 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An ratings[pair[0]] += 1 # if parsing unsuccessful, default to the first response else: ratings[pair[0]] += 1 # if parsing unsuccessful, default to the first response - + best_index = ratings.index(max(ratings)) cb_log["ratings"] = ratings cb_log["best_index"] = best_index @@ -905,12 +921,15 @@ def majority_vote_math(completions, last_n_chars=100): extracted_answer_map.append((response, extracted_answer)) counts = Counter(answer for _, answer in extracted_answer_map) - majority_answer, count = counts.most_common(1)[0] + common_list = counts.most_common(1) + majority_answer, count = common_list and common_list[0] or (None, 0) for response, answer in extracted_answer_map: if answer == majority_answer: return response, count - return 
extracted_answer_map[0][0], 0 + if extracted_answer_map and extracted_answer_map[0]: + return extracted_answer_map[0][0], 0 + return None, 0 def majority_vote_mcq(completions, last_n_chars=100): @@ -920,13 +939,16 @@ def majority_vote_mcq(completions, last_n_chars=100): extracted_answer_map.append((response, extracted_answer)) counts = Counter(answer for _, answer in extracted_answer_map) - majority_answer, count = counts.most_common(1)[0] + common_list = counts.most_common(1) + majority_answer, count = common_list and common_list[0] or (None, 0) for response, answer in extracted_answer_map: if answer == majority_answer: return response, count - return extracted_answer_map[0][0], 0 - + if extracted_answer_map and extracted_answer_map[0]: + return extracted_answer_map[0][0], 0 + return None, 0 + def rate_completions_majority(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]: mcq_majority, count = majority_vote_mcq(completions, last_n_chars) @@ -939,16 +961,16 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c """ Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one. Each completion is generated as follows: - + Generate `planning_n` solution proposals: - Step 1: Plan Generation - The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for + Step 1: Plan Generation - The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for each step. Step 2: Initial Solution - Using the plan from Step 1, the model produces an initial solution. - + Step 3: Plan Refinement - The model reviews all generated solution proposals and their associated plans, identifying inconsistencies. Based on this analysis, a refined, final step-by-step plan is constructed. Step 4: Final Solution - The model uses the refined plan from Step 3 to produce the final answer. - + Parameters: system_prompt (str): The system prompt to guide the model. initial_query (str): The task or question to be addressed. 
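
Note: guarding `counts.most_common(1)` against an empty result is the key change in both majority-vote helpers above. A compact sketch of the same empty-safe logic, folded into one generic helper for illustration; `extract_answer` here is a stand-in for the MCQ/math extraction code in cepo.py, not an existing function.

```python
from collections import Counter
from typing import Callable, Optional

def majority_vote(completions: list[str],
                  extract_answer: Callable[[str], Optional[str]],
                  last_n_chars: int = 100) -> tuple[Optional[str], int]:
    """Return the completion backing the most common answer and its vote count."""
    answer_map = [(c, extract_answer(c[-last_n_chars:])) for c in completions]

    counts = Counter(answer for _, answer in answer_map)
    common = counts.most_common(1)
    # Guard against an empty completions list instead of indexing [0] directly.
    majority_answer, count = common[0] if common else (None, 0)

    for completion, answer in answer_map:
        if answer == majority_answer:
            return completion, count

    # Fall back to the first completion when nothing matched, or None if empty.
    return (answer_map[0][0], 0) if answer_map else (None, 0)
```

The conditional expression is equivalent to the `common_list and common_list[0] or (None, 0)` form used in the diff, just spelled out explicitly.
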
@@ -960,6 +982,8 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c Tuple[str, int, dict]: The generated completion, number of tokens used """ + logger.info("CePO Started; Generating completions") + # Generate completions completions, completion_tokens_planning, cb_log = generate_n_completions(system_prompt, initial_query, client, model, cepo_config, request_id) # cb_log is a dictionary for debugging purposes completions = [c for c in completions if c] # safeguard in case completion is None (observed with GPT OSS) @@ -975,5 +999,5 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c completion_tokens_rating = 0 else: raise ValueError("Invalid rating type in cepo_config") - + return best_completion, completion_tokens_planning + completion_tokens_rating diff --git a/optillm/cepo/configs/cepo_config.yaml b/optillm/cepo/configs/cepo_config.yaml index 7b28957a..7d747f12 100644 --- a/optillm/cepo/configs/cepo_config.yaml +++ b/optillm/cepo/configs/cepo_config.yaml @@ -16,6 +16,7 @@ planning_max_tokens_step3: 4096 planning_max_tokens_step4: 4096 use_plan_diversity: False rating_model: null +use_reasoning: True use_reasoning_fallback: False num_of_retries: 0 -print_output: False \ No newline at end of file +print_output: False diff --git a/optillm/cepo/configs/cepo_config_gptoss.yaml b/optillm/cepo/configs/cepo_config_gptoss.yaml index 72d78252..bd540a04 100644 --- a/optillm/cepo/configs/cepo_config_gptoss.yaml +++ b/optillm/cepo/configs/cepo_config_gptoss.yaml @@ -16,6 +16,7 @@ planning_max_tokens_step3: 40960 planning_max_tokens_step4: 40960 use_plan_diversity: False rating_model: null +use_reasoning: True use_reasoning_fallback: True num_of_retries: 2 -print_output: true \ No newline at end of file +print_output: true diff --git a/optillm/cepo/configs/cepo_config_qwen3.yaml b/optillm/cepo/configs/cepo_config_qwen3.yaml index af1b8dfa..1811d580 100644 --- a/optillm/cepo/configs/cepo_config_qwen3.yaml +++ b/optillm/cepo/configs/cepo_config_qwen3.yaml @@ -16,6 +16,7 @@ planning_max_tokens_step3: 20481 planning_max_tokens_step4: 20482 use_plan_diversity: False rating_model: null +use_reasoning: True use_reasoning_fallback: False num_of_retries: 0 -print_output: False \ No newline at end of file +print_output: False diff --git a/optillm/conversation_logger.py b/optillm/conversation_logger.py index 6c77cd3b..cdcd0c92 100644 --- a/optillm/conversation_logger.py +++ b/optillm/conversation_logger.py @@ -7,6 +7,7 @@ from typing import Dict, Any, Optional, List from dataclasses import dataclass, field import time +import copy logger = logging.getLogger(__name__) @@ -30,55 +31,55 @@ class ConversationEntry: class ConversationLogger: """ Logger for OptiLLM conversations including all provider interactions and metadata. - + Logs are saved in JSONL format (one JSON object per line) with daily rotation. Each entry contains the full conversation including all intermediate provider calls. """ - + def __init__(self, log_dir: Path, enabled: bool = False): self.enabled = enabled self.log_dir = log_dir self.active_entries: Dict[str, ConversationEntry] = {} self._lock = threading.Lock() - + if self.enabled: self.log_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Conversation logging enabled. 
Logs will be saved to: {self.log_dir}") else: logger.debug("Conversation logging disabled") - + def _get_log_file_path(self, timestamp: datetime = None) -> Path: """Get the log file path for a given timestamp (defaults to now)""" if timestamp is None: timestamp = datetime.now(timezone.utc) date_str = timestamp.strftime("%Y-%m-%d") return self.log_dir / f"conversations_{date_str}.jsonl" - + def _generate_request_id(self) -> str: """Generate a unique request ID""" return f"req_{uuid.uuid4().hex[:8]}" - - def start_conversation(self, - client_request: Dict[str, Any], - approach: str, + + def start_conversation(self, + client_request: Dict[str, Any], + approach: str, model: str) -> str: """ Start logging a new conversation. - + Args: client_request: The original request from the client approach: The optimization approach being used model: The model name - + Returns: str: Unique request ID for this conversation """ if not self.enabled: return "" - + request_id = self._generate_request_id() timestamp = datetime.now(timezone.utc).isoformat() - + entry = ConversationEntry( request_id=request_id, timestamp=timestamp, @@ -86,20 +87,20 @@ def start_conversation(self, model=model, client_request=client_request.copy() ) - + with self._lock: self.active_entries[request_id] = entry - + logger.debug(f"Started conversation logging for request {request_id}") return request_id - - def log_provider_call(self, - request_id: str, - provider_request: Dict[str, Any], + + def log_provider_call(self, + request_id: str, + provider_request: Dict[str, Any], provider_response: Dict[str, Any]) -> None: """ Log a provider API call and response. - + Args: request_id: The request ID for this conversation provider_request: The request sent to the provider @@ -107,86 +108,86 @@ def log_provider_call(self, """ if not self.enabled or not request_id: return - + with self._lock: entry = self.active_entries.get(request_id) if not entry: logger.warning(f"No active conversation found for request {request_id}") return - + call_data = { "call_number": len(entry.provider_calls) + 1, "timestamp": datetime.now(timezone.utc).isoformat(), - "request": provider_request.copy(), - "response": provider_response.copy() + "request": provider_request and provider_request.copy() or None, + "response": provider_response and copy.copy(provider_response) or None # Responses are usually strs or dicts } - + entry.provider_calls.append(call_data) - + logger.debug(f"Logged provider call #{len(entry.provider_calls)} for request {request_id}") - - def log_final_response(self, - request_id: str, + + def log_final_response(self, + request_id: str, final_response: Dict[str, Any]) -> None: """ Log the final response sent back to the client. - + Args: request_id: The request ID for this conversation final_response: The final response sent to the client """ if not self.enabled or not request_id: return - + with self._lock: entry = self.active_entries.get(request_id) if not entry: logger.warning(f"No active conversation found for request {request_id}") return - + entry.final_response = final_response.copy() entry.final_response["timestamp"] = datetime.now(timezone.utc).isoformat() - + def log_error(self, request_id: str, error: str) -> None: """ Log an error for this conversation. 
- + Args: - request_id: The request ID for this conversation + request_id: The request ID for this conversation error: Error message or description """ if not self.enabled or not request_id: return - + with self._lock: entry = self.active_entries.get(request_id) if not entry: logger.warning(f"No active conversation found for request {request_id}") return - + entry.error = error - + logger.debug(f"Logged error for request {request_id}: {error}") - + def finalize_conversation(self, request_id: str) -> None: """ Finalize and save the conversation to disk. - + Args: request_id: The request ID for this conversation """ if not self.enabled or not request_id: return - + with self._lock: entry = self.active_entries.pop(request_id, None) if not entry: logger.warning(f"No active conversation found for request {request_id}") return - + # Calculate total duration entry.total_duration_ms = int((time.time() - entry.start_time) * 1000) - + # Convert to dict for JSON serialization log_entry = { "timestamp": entry.timestamp, @@ -199,12 +200,12 @@ def finalize_conversation(self, request_id: str) -> None: "total_duration_ms": entry.total_duration_ms, "error": entry.error } - + # Write to log file self._write_log_entry(log_entry) - + logger.debug(f"Finalized conversation for request {request_id}") - + def _write_log_entry(self, log_entry: Dict[str, Any]) -> None: """Write a log entry to the appropriate JSONL file""" try: @@ -215,18 +216,18 @@ def _write_log_entry(self, log_entry: Dict[str, Any]) -> None: logger.debug(f"Wrote log entry to {log_file_path}") except Exception as e: logger.error(f"Failed to write log entry: {e}") - + def get_stats(self) -> Dict[str, Any]: """Get statistics about conversation logging""" with self._lock: active_count = len(self.active_entries) - + stats = { "enabled": self.enabled, "log_dir": str(self.log_dir), "active_conversations": active_count } - + if self.enabled: # Count total log files and approximate total entries log_files = list(self.log_dir.glob("conversations_*.jsonl")) @@ -237,12 +238,12 @@ def get_stats(self) -> Dict[str, Any]: total_entries += sum(1 for line in f if line.strip()) except Exception: pass - + stats.update({ "log_files_count": len(log_files), "total_entries_approximate": total_entries }) - + return stats @@ -262,4 +263,4 @@ def log_provider_call(request_id: str, provider_request: Dict[str, Any], provide def log_error(request_id: str, error_message: str) -> None: """Log an error using the global logger instance""" if _global_logger and _global_logger.enabled: - _global_logger.log_error(request_id, error_message) \ No newline at end of file + _global_logger.log_error(request_id, error_message) diff --git a/optillm/server.py b/optillm/server.py index 243c73c0..bc2c88f6 100644 --- a/optillm/server.py +++ b/optillm/server.py @@ -3,6 +3,7 @@ import os import secrets import time +import traceback from pathlib import Path from flask import Flask, request, jsonify from cerebras.cloud.sdk import Cerebras @@ -131,37 +132,37 @@ def get_config(): def count_reasoning_tokens(text: str, tokenizer=None) -> int: """ Count tokens within ... tags in the given text. - + Args: text: The text to analyze tokenizer: Optional tokenizer instance for precise counting - + Returns: Number of reasoning tokens (0 if no think tags found) """ if not text or not isinstance(text, str): return 0 - + # Extract all content within ... tags # Handle both complete and truncated think blocks - + # First, find all complete ... 
blocks complete_pattern = r'(.*?)' complete_matches = re.findall(complete_pattern, text, re.DOTALL) - + # Then check for unclosed tag (truncated response) # This finds that doesn't have a matching after it truncated_pattern = r'(?!.*)(.*)$' truncated_match = re.search(truncated_pattern, text, re.DOTALL) - + # Combine all thinking content thinking_content = ''.join(complete_matches) if truncated_match: thinking_content += truncated_match.group(1) - + if not thinking_content: return 0 - + if tokenizer and hasattr(tokenizer, 'encode'): # Use tokenizer for precise counting try: @@ -169,7 +170,7 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int: return len(tokens) except Exception as e: logger.warning(f"Failed to count tokens with tokenizer: {e}") - + # Fallback: rough estimation (4 chars per token on average, minimum 1 token for non-empty content) content_length = len(thinking_content.strip()) return max(1, content_length // 4) if content_length > 0 else 0 @@ -211,22 +212,22 @@ def normalize_message_content(messages): for message in messages: normalized_message = message.copy() content = message.get('content', '') - + # Convert list content to string if needed if isinstance(content, list): # Extract text content from the list text_content = ' '.join( - item.get('text', '') for item in content + item.get('text', '') for item in content if isinstance(item, dict) and item.get('type') == 'text' ) normalized_message['content'] = text_content - + normalized_messages.append(normalized_message) - + return normalized_messages def none_approach( - client: Any, + client: Any, model: str, original_messages: List[Dict[str, str]], request_id: str = None, @@ -234,48 +235,48 @@ def none_approach( ) -> Dict[str, Any]: """ Direct proxy approach that passes through all parameters to the underlying endpoint. 
- + Args: client: OpenAI client instance model: Model identifier original_messages: Original messages from the request request_id: Optional request ID for conversation logging **kwargs: Additional parameters to pass through - + Returns: Dict[str, Any]: Full OpenAI API response """ # Strip 'none-' prefix from model if present if model.startswith('none-'): model = model[5:] - + try: # Normalize message content to ensure it's always string normalized_messages = normalize_message_content(original_messages) - + # Prepare request data for logging provider_request = { "model": model, "messages": normalized_messages, **kwargs } - + # Make the direct completion call with normalized messages and parameters response = client.chat.completions.create( model=model, messages=normalized_messages, **kwargs ) - + # Convert to dict if it's not already response_dict = response.model_dump() if hasattr(response, 'model_dump') else response - + # Log the provider call if conversation logging is enabled if conversation_logger and request_id: conversation_logger.log_provider_call(request_id, provider_request, response_dict) - + return response_dict - + except Exception as e: # Log error if conversation logging is enabled if conversation_logger and request_id: @@ -286,45 +287,45 @@ def none_approach( def load_plugins(): # Clear existing plugins first but modify the global dict in place plugin_approaches.clear() - + # Get installed package plugins directory import optillm package_plugin_dir = os.path.join(os.path.dirname(optillm.__file__), 'plugins') - + # Get local project plugins directory current_dir = os.getcwd() if server_config.get("plugins_dir", "") == "" else server_config["plugins_dir"] local_plugin_dir = os.path.join(current_dir, 'optillm', 'plugins') - + plugin_dirs = [] - + # Add package plugin dir plugin_dirs.append((package_plugin_dir, "package")) - + # Add local plugin dir only if it's different from package dir if local_plugin_dir != package_plugin_dir: plugin_dirs.append((local_plugin_dir, "local")) - + for plugin_dir, source in plugin_dirs: logger.info(f"Looking for {source} plugins in: {plugin_dir}") - + if not os.path.exists(plugin_dir): logger.debug(f"{source.capitalize()} plugin directory not found: {plugin_dir}") continue - + plugin_files = glob.glob(os.path.join(plugin_dir, '*.py')) if not plugin_files: logger.debug(f"No plugin files found in {source} directory: {plugin_dir}") continue - + logger.info(f"Found {source} plugin files: {plugin_files}") - + for plugin_file in plugin_files: try: module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension spec = importlib.util.spec_from_file_location(module_name, plugin_file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + if hasattr(module, 'SLUG') and hasattr(module, 'run'): if module.SLUG in plugin_approaches: logger.info(f"Overriding {source} plugin: {module.SLUG}") @@ -334,7 +335,7 @@ def load_plugins(): logger.warning(f"Plugin {module_name} from {source} missing required attributes (SLUG and run)") except Exception as e: logger.error(f"Error loading {source} plugin {plugin_file}: {str(e)}") - + if not plugin_approaches: logger.warning("No plugins loaded from any location") @@ -343,17 +344,17 @@ def get_config_path(): import optillm package_config_dir = os.path.join(os.path.dirname(optillm.__file__), 'cepo', 'configs') package_config_path = os.path.join(package_config_dir, 'cepo_config.yaml') - + # Get local project config directory current_dir = os.getcwd() if server_config.get("config_dir", "") 
== "" else server_config["config_dir"] local_config_dir = os.path.join(current_dir, 'optillm', 'cepo', 'configs') local_config_path = os.path.join(local_config_dir, 'cepo_config.yaml') - + # If local config exists and is different from package config, use local if os.path.exists(local_config_path) and local_config_path != package_config_path: logger.debug(f"Using local config from: {local_config_path}") return local_config_path - + # Otherwise use package config logger.debug(f"Using package config from: {package_config_path}") return package_config_path @@ -391,7 +392,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache actual_model = '-'.join(model_parts) return operation, approaches, actual_model - + def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None): if approach in known_approaches: if approach == 'none': @@ -401,14 +402,14 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode # Remove items that are handled separately by the framework # Note: 'n' is NOT removed - the none_approach passes it to the client which handles multiple completions kwargs.pop('stream', None) # stream is handled by proxy() - + # Reconstruct original messages from system_prompt and initial_query messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) if initial_query: messages.append({"role": "user", "content": initial_query}) - + logger.debug(f"none_approach kwargs: {kwargs}") response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs) # For none approach, we return the response and a token count of 0 @@ -452,10 +453,10 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode plugin_func = plugin_approaches[approach] import inspect sig = inspect.signature(plugin_func) - + # Check if the plugin function is async is_async = inspect.iscoroutinefunction(plugin_func) - + if is_async: # For async functions, we need to run them in an event loop import asyncio @@ -481,7 +482,7 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode return plugin_func(system_prompt, initial_query, client, model) else: raise ValueError(f"Unknown approach: {approach}") - + def execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None): final_response = initial_query total_tokens = 0 @@ -504,7 +505,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]: """ Execute the pipeline n times and return n responses. 
- + Args: n (int): Number of times to run the pipeline approaches (list): List of approaches to execute @@ -513,13 +514,13 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init initial_query (str): Initial query client: OpenAI client instance model (str): Model identifier - + Returns: Tuple[Union[str, List[str]], int]: List of responses and total token count """ responses = [] total_tokens = 0 - + for _ in range(n): if operation == 'SINGLE': response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) @@ -532,7 +533,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init loop.close() else: raise ValueError(f"Unknown operation: {operation}") - + # If response is already a list (from OR operation), extend responses # Otherwise append the single response if isinstance(response, list): @@ -540,7 +541,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init else: responses.append(response) total_tokens += tokens - + # If n=1 and we got a single response, return it as is # Otherwise return the list of responses if n == 1 and len(responses) == 1: @@ -580,36 +581,36 @@ def extract_contents(response_obj): contents = [] # Handle both single response and list of responses responses = response_obj if isinstance(response_obj, list) else [response_obj] - + for response in responses: # Extract content from first choice if it exists - if (response.get('choices') and - len(response['choices']) > 0 and - response['choices'][0].get('message') and + if (response.get('choices') and + len(response['choices']) > 0 and + response['choices'][0].get('message') and response['choices'][0]['message'].get('content')): contents.append(response['choices'][0]['message']['content']) - + return contents def parse_conversation(messages): system_prompt = "" conversation = [] optillm_approach = None - + for message in messages: role = message['role'] content = message['content'] - + # Handle content that could be a list or string if isinstance(content, list): # Extract text content from the list text_content = ' '.join( - item['text'] for item in content + item['text'] for item in content if isinstance(item, dict) and item.get('type') == 'text' ) else: text_content = content - + if role == 'system': system_prompt, optillm_approach = extract_optillm_approach(text_content) elif role == 'user': @@ -618,35 +619,35 @@ def parse_conversation(messages): conversation.append(f"User: {text_content}") elif role == 'assistant': conversation.append(f"Assistant: {text_content}") - + initial_query = "\n".join(conversation) return system_prompt, initial_query, optillm_approach def tagged_conversation_to_messages(response_text): """Convert a tagged conversation string or list of strings into a list of messages. If the input doesn't contain User:/Assistant: tags, return it as is. - + Args: response_text: Either a string containing "User:" and "Assistant:" tags, or a list of such strings. - + Returns: If input has tags: A list of message dictionaries. If input has no tags: The original input. 
""" def has_conversation_tags(text): return "User:" in text or "Assistant:" in text - + def process_single_response(text): if not has_conversation_tags(text): return text - + messages = [] # Split on "User:" or "Assistant:" while keeping the delimiter parts = re.split(r'(?=(User:|Assistant:))', text.strip()) # Remove empty strings parts = [p for p in parts if p.strip()] - + for part in parts: part = part.strip() if part.startswith('User:'): @@ -668,7 +669,7 @@ def process_single_response(text): return response_text return processed else: - return process_single_response(response_text) + return process_single_response(response_text or "") def extract_optillm_approach(content): match = re.search(r'(.*?)', content) @@ -703,7 +704,7 @@ def proxy(): if auth_header and auth_header.startswith("Bearer "): bearer_token = auth_header.split("Bearer ")[1].strip() logger.debug(f"Intercepted Bearer Token: {bearer_token}") - + logger.debug(f'Request data: {data}') stream = data.get('stream', False) @@ -787,7 +788,7 @@ def proxy(): client = OpenAI(api_key=api_key, base_url=base_url) else: client = OpenAI(api_key=api_key) - else: + else: client = default_client try: @@ -807,22 +808,22 @@ def proxy(): 'stream': stream, 'optillm_approach': optillm_approach } - + logger.debug("Routing request to batch processor") result = request_batcher.add_request(batch_request_data) return jsonify(result), 200 - + except BatchingError as e: logger.error(f"Batch processing failed: {e}") return jsonify({"error": str(e)}), 500 - + # Check if any of the approaches is 'none' contains_none = any(approach == 'none' for approach in approaches) if operation == 'SINGLE' and approaches[0] == 'none': # Pass through the request including the n parameter result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) - + logger.debug(f'Direct proxy response: {result}') # Log the final response and finalize conversation logging @@ -833,26 +834,26 @@ def proxy(): if stream: if request_id: logger.info(f'Request {request_id}: Completed (streaming response)') - return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream') + return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream') else : if request_id: logger.info(f'Request {request_id}: Completed') return jsonify(result), 200 - + elif operation == 'AND' or operation == 'OR': if contains_none: raise ValueError("'none' approach cannot be combined with other approaches") # Handle non-none approaches with n attempts response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id) - + # Check if the response is a full dict (like from proxy plugin or none approach) if operation == 'SINGLE' and isinstance(response, dict) and 'choices' in response and 'usage' in response: # This is a full response dict, return it directly if conversation_logger and request_id: conversation_logger.log_final_response(request_id, response) conversation_logger.finalize_conversation(request_id) - + if stream: if request_id: logger.info(f'Request {request_id}: Completed (streaming response)') @@ -867,9 +868,11 @@ def proxy(): if conversation_logger and request_id: conversation_logger.log_error(request_id, str(e)) conversation_logger.finalize_conversation(request_id) - + request_id_str = f' {request_id}' if request_id else '' logger.error(f"Error processing 
request{request_id_str}: {str(e)}") + if logger.getEffectiveLevel() == logging.DEBUG: + logger.exception("Debug request exception") return jsonify({"error": str(e)}), 500 # Convert tagged conversation to messages format if needed @@ -877,7 +880,7 @@ def proxy(): processed_response = tagged_conversation_to_messages(response) # If processed_response is a list of message lists, extract last message content if processed_response != response: # Only process if format changed - response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg + response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg for msg in processed_response] # Otherwise keep original response else: @@ -895,7 +898,7 @@ def proxy(): elif isinstance(response, list) and response: # For multiple responses, sum up reasoning tokens from all reasoning_tokens = sum(count_reasoning_tokens(resp) for resp in response if isinstance(resp, str)) - + response_data = { 'model': model, 'choices': [], @@ -955,7 +958,7 @@ def proxy_models(): # For local inference, create a models response manually current_model = server_config.get('model', 'gpt-3.5-turbo') models_data = { - "object": "list", + "object": "list", "data": [ { "id": current_model, @@ -983,8 +986,8 @@ def parse_args(): from optillm import __version__ as package_version except ImportError: package_version = "unknown" - - parser.add_argument('--version', action='version', + + parser.add_argument('--version', action='version', version=f'%(prog)s {package_version}', help="Show program's version number and exit") @@ -1055,7 +1058,7 @@ def parse_args(): batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true" batch_size_default = int(os.environ.get("OPTILLM_BATCH_SIZE", 4)) batch_wait_ms_default = int(os.environ.get("OPTILLM_BATCH_WAIT_MS", 50)) - + parser.add_argument("--batch-mode", action="store_true", default=batch_mode_default, help="Enable automatic request batching (fail-fast, no fallback)") parser.add_argument("--batch-size", type=int, default=batch_size_default, @@ -1065,18 +1068,18 @@ def parse_args(): # Special handling of all the CePO Configurations for field in fields(CepoConfig): - parser.add_argument(f"--cepo_{field.name}", - dest=f"cepo_{field.name}", - type=field.type, - default=None, + parser.add_argument(f"--cepo_{field.name}", + dest=f"cepo_{field.name}", + type=field.type, + default=None, help=f"CePO configuration for {field.name}") - parser.add_argument("--cepo_config_file", - dest="cepo_config_file", - type=str, + parser.add_argument("--cepo_config_file", + dest="cepo_config_file", + type=str, default=default_config_path, help="Path to CePO configuration file") - + args = parser.parse_args() # Convert argument names to match server_config keys @@ -1094,16 +1097,16 @@ def main(): global request_batcher global conversation_logger # Call this function at the start of main() - + # Load plugins first so they're available in argument parser load_plugins() - + args = parse_args() # Update server_config with all argument values server_config.update(vars(args)) port = server_config['port'] - + # Initialize request batcher if batch mode is enabled if server_config.get('batch_mode', False): logger.info(f"Batch mode enabled: size={server_config['batch_size']}, " @@ -1113,47 +1116,47 @@ def main(): max_wait_ms=server_config['batch_wait_ms'], enable_logging=True ) - + # Set up the batch processor function def process_batch_requests(batch_requests): """ Process a batch of requests using true batching when possible - + Args: 
batch_requests: List of request data dictionaries - + Returns: List of response dictionaries """ import time from optillm.batching import BatchingError - + if not batch_requests: return [] - + logger.info(f"Processing batch of {len(batch_requests)} requests") - + # Check if we can use true batching (all requests compatible and using 'none' approach) can_use_true_batching = True first_req = batch_requests[0] - + # Check compatibility across all requests for req_data in batch_requests: - if (req_data['stream'] or + if (req_data['stream'] or req_data['approaches'] != first_req['approaches'] or req_data['operation'] != first_req['operation'] or req_data['model'] != first_req['model']): can_use_true_batching = False break - + # For now, implement sequential processing but with proper infrastructure # TODO: Implement true PyTorch/MLX batching in next phase responses = [] - + for i, req_data in enumerate(batch_requests): try: logger.debug(f"Processing batch request {i+1}/{len(batch_requests)}") - + # Extract request parameters system_prompt = req_data['system_prompt'] initial_query = req_data['initial_query'] @@ -1164,14 +1167,14 @@ def process_batch_requests(batch_requests): operation = req_data['operation'] n = req_data['n'] stream = req_data['stream'] - + # Validate request if stream: raise BatchingError("Streaming requests cannot be batched") - + # Check if any of the approaches is 'none' contains_none = any(approach == 'none' for approach in approaches) - + if operation == 'SINGLE' and approaches[0] == 'none': # Pass through the request including the n parameter result, completion_tokens = execute_single_approach( @@ -1186,18 +1189,18 @@ def process_batch_requests(batch_requests): # Handle non-none approaches with n attempts result, completion_tokens = execute_n_times( n, approaches, operation, system_prompt, initial_query, client, model, request_config) - + # Convert tagged conversation to messages format if needed if isinstance(result, list): processed_response = tagged_conversation_to_messages(result) if processed_response != result: # Only process if format changed - result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg + result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg for msg in processed_response] else: messages = tagged_conversation_to_messages(result) if isinstance(messages, list) and messages: # Only process if format changed result = messages[-1]['content'] - + # Generate the response in OpenAI format if isinstance(result, list): choices = [] @@ -1214,12 +1217,12 @@ def process_batch_requests(batch_requests): choices = [{ "index": 0, "message": { - "role": "assistant", + "role": "assistant", "content": result }, "finish_reason": "stop" }] - + response_dict = { "id": f"chatcmpl-{int(time.time()*1000)}-{i}", "object": "chat.completion", @@ -1232,16 +1235,16 @@ def process_batch_requests(batch_requests): "total_tokens": completion_tokens if isinstance(completion_tokens, int) else 0 } } - + responses.append(response_dict) - + except Exception as e: logger.error(f"Error processing batch request {i+1}: {e}") raise BatchingError(f"Failed to process request {i+1}: {str(e)}") - + logger.info(f"Completed batch processing of {len(responses)} requests") return responses - + # Set the processor function on the batcher request_batcher.set_processor(process_batch_requests) @@ -1249,7 +1252,7 @@ def process_batch_requests(batch_requests): logging_level = server_config['log'] if logging_level in logging_levels.keys(): 
logger.setLevel(logging_levels[logging_level]) - + # Initialize conversation logger if enabled global conversation_logger conversation_logger = ConversationLogger( @@ -1260,12 +1263,12 @@ def process_batch_requests(batch_requests): optillm.conversation_logger.set_global_logger(conversation_logger) if server_config['log_conversations']: logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}") - + # set and log the cepo configs cepo_config = init_cepo_config(server_config) if args.approach == 'cepo': logger.info(f"CePO Config: {cepo_config}") - + logger.info(f"Starting server with approach: {server_config['approach']}") server_config_clean = server_config.copy() if server_config_clean['optillm_api_key']: @@ -1282,16 +1285,16 @@ def process_batch_requests(batch_requests): server_thread = threading.Thread(target=app.run, kwargs={'host': host, 'port': port}) server_thread.daemon = True server_thread.start() - + # Configure the base URL for the Gradio interface base_url = f"http://localhost:{port}/v1" logger.info(f"Launching Gradio interface connected to {base_url}") - + # Create custom chat function with extended timeout def chat_with_optillm(message, history): import httpx from openai import OpenAI - + # Create client with extended timeout and no retries custom_client = OpenAI( api_key="optillm", @@ -1299,7 +1302,7 @@ def chat_with_optillm(message, history): timeout=httpx.Timeout(1800.0, connect=5.0), # 30 min timeout max_retries=0 # No retries - prevents duplicate requests ) - + # Convert history to messages format messages = [] for h in history: @@ -1308,7 +1311,7 @@ def chat_with_optillm(message, history): if h[1]: # Assistant message messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "user", "content": message}) - + # Make request try: response = custom_client.chat.completions.create( @@ -1318,7 +1321,7 @@ def chat_with_optillm(message, history): return response.choices[0].message.content except Exception as e: return f"Error: {str(e)}" - + # Create Gradio interface with queue for long operations demo = gr.ChatInterface( chat_with_optillm, @@ -1330,8 +1333,8 @@ def chat_with_optillm(message, history): except ImportError: logger.error("Gradio is required for GUI. Install it with: pip install gradio") return - + app.run(host=server_config['host'], port=port) if __name__ == "__main__": - main() \ No newline at end of file + main()
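
Note: the new `logger.debug(...)` / `logger.exception(...)` calls added throughout this diff only emit when the effective log level is DEBUG, which the server sets from its existing `log` option. For exercising the CePO module outside the server, a minimal way to surface them is plain stdlib logging; the `optillm.cepo.cepo` logger name is an assumption based on the module path.

```python
import logging

# Emit DEBUG-level records globally, including the new retry/fallback messages.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)

# Or scope DEBUG output to the CePO module only (name assumed from optillm/cepo/cepo.py).
logging.getLogger("optillm.cepo.cepo").setLevel(logging.DEBUG)
```
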