From 0fa45e50b239b6b03f8f13bcf2110564d4cbd0d7 Mon Sep 17 00:00:00 2001
From: Parth Bhagat
Date: Mon, 14 Apr 2025 22:55:48 -0700
Subject: [PATCH] black reformatting add support for critique api for search
 arena.

---
 fastchat/model/model_registry.py |   7 ++
 fastchat/serve/api_provider.py   | 171 +++++++++++++++++++++++++++++++
 2 files changed, 178 insertions(+)

diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 2eed9649e..71d927f3f 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -270,6 +270,13 @@ def get_model_info(name: str) -> ModelInfo:
     "Frontier Multimodal Language Model by Reka",
 )
 
+register_model_info(
+    ["critique-agentic-search", "critique-agentic-search-api", "critique-labs-ai"],
+    "Critique Labs AI",
+    "https://www.critique-labs.ai/",
+    "Agentic Search Engine By Critique Labs AI",
+)
+
 register_model_info(
     ["gemini-pro", "gemini-pro-dev-api"],
     "Gemini",
diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
index 2e967e3ef..441038b2c 100644
--- a/fastchat/serve/api_provider.py
+++ b/fastchat/serve/api_provider.py
@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             extra_body=extra_body,
         )
+    elif model_api_dict["api_type"] == "critique-labs-ai":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = critique_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_key=model_api_dict.get("api_key"),
+            api_base=model_api_dict.get("api_base"),
+        )
     else:
         raise NotImplementedError()
 
@@ -1345,3 +1356,163 @@ def metagen_api_stream_iter(
             "text": f"**API REQUEST ERROR** Reason: Unknown.",
             "error_code": 1,
         }
+
+
+def critique_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_key=None,
+    api_base=None,
+):
+    import websockets
+    import threading
+    import queue
+    import json
+    import time
+
+    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
+    if not api_key:
+        yield {
+            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
+            "error_code": 1,
+        }
+        return
+
+    # Combine all messages into a single prompt
+    prompt = ""
+    for message in messages:
+        if isinstance(message["content"], str):
+            role_prefix = (
+                f"{message['role'].capitalize()}: "
+                if message["role"] != "system"
+                else ""
+            )
+            prompt += f"{role_prefix}{message['content']}\n"
+        else:  # Handle content that might be a list (for multimodal)
+            for content_item in message["content"]:
+                if content_item.get("type") == "text":
+                    role_prefix = (
+                        f"{message['role'].capitalize()}: "
+                        if message["role"] != "system"
+                        else ""
+                    )
+                    prompt += f"{role_prefix}{content_item['text']}\n"
+    prompt += "\n DO NOT RESPOND IN MARKDOWN or provide any citations"
+
+    # Log request parameters
+    gen_params = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    # Create a queue for communication between threads
+    response_queue = queue.Queue()
+    stop_event = threading.Event()
+    connection_closed = threading.Event()
+
+    # Thread function to handle WebSocket communication
+    def websocket_thread():
+        import asyncio
+
+        async def connect_and_stream():
+            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"
+
+            try:
+                # Create connection with headers in the correct format
+                async with websockets.connect(
+                    uri, additional_headers={"X-API-Key": api_key}
+                ) as websocket:
+                    # Send the search request
+                    await websocket.send(
+                        json.dumps(
+                            {
+                                "prompt": prompt,
+                            }
+                        )
+                    )
+
+                    # Receive and process streaming responses
+                    while not stop_event.is_set():
+                        try:
+                            response = await websocket.recv()
+                            data = json.loads(response)
+                            response_queue.put(data)
+
+                            # If we get an error, we're done
+                            if data["type"] == "error":
+                                break
+                        except websockets.exceptions.ConnectionClosed:
+                            # This is the expected end signal - not an error
+                            logger.info(
+                                "WebSocket connection closed by server - this is the expected end signal"
+                            )
+                            connection_closed.set()  # Signal that the connection was closed normally
+                            break
+            except Exception as e:
+                # Only log as error for unexpected exceptions
+                logger.error(f"WebSocket error: {str(e)}")
+                response_queue.put(
+                    {"type": "error", "content": f"WebSocket error: {str(e)}"}
+                )
+            finally:
+                # Always set connection_closed when we exit
+                connection_closed.set()
+
+        asyncio.run(connect_and_stream())
+
+    # Start the WebSocket thread
+    thread = threading.Thread(target=websocket_thread)
+    thread.daemon = True
+    thread.start()
+
+    try:
+        text = ""
+        context_info = []
+
+        # Process responses from the queue until connection is closed
+        while not connection_closed.is_set() or not response_queue.empty():
+            try:
+                # Wait for a response with timeout
+                data = response_queue.get(
+                    timeout=0.5
+                )  # Short timeout to check connection_closed frequently
+
+                if data["type"] == "response":
+                    text += data["content"]
+                    yield {
+                        "text": text,
+                        "error_code": 0,
+                    }
+                elif data["type"] == "context":
+                    # Collect context information
+                    context_info.append(data["content"])
+                elif data["type"] == "error":
+                    logger.error(f"Critique API error: {data['content']}")
+                    yield {
+                        "text": f"**API REQUEST ERROR** Reason: {data['content']}",
+                        "error_code": 1,
+                    }
+                    break
+
+                response_queue.task_done()
+            except queue.Empty:
+                # Just a timeout to check if connection is closed
+                continue
+
+    except Exception as e:
+        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
+        yield {
+            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
+            "error_code": 1,
+        }
+    finally:
+        # Signal the thread to stop and wait for it to finish
+        stop_event.set()
+        thread.join(timeout=5)
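
Note for reviewers trying this out: the new elif branch in get_api_provider_stream_iter only fires when the gradio server supplies a model_api_dict whose api_type is "critique-labs-ai". Below is a minimal sketch of such an endpoint entry, assuming FastChat's usual --register-api-endpoint-file JSON mechanism and using only the fields this patch actually reads (model_name, api_type, api_key, api_base); the key value is a placeholder and the api_base simply repeats the default WebSocket URI hard-coded in critique_api_stream_iter.

{
  "critique-agentic-search-api": {
    "model_name": "critique-agentic-search",
    "api_type": "critique-labs-ai",
    "api_base": "wss://api.critique-labs.ai/v1/ws/search",
    "api_key": "YOUR_CRITIQUE_API_KEY"
  }
}

If api_key or api_base is omitted, the new code falls back to the CRITIQUE_API_KEY environment variable and the default URI, respectively.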
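
Like the other *_api_stream_iter helpers in this file, critique_api_stream_iter is a generator that yields dicts of the form {"text": <full answer so far>, "error_code": 0}, and it expects the server to send JSON frames carrying a "type" of "response", "context", or "error" plus a "content" payload. The following is a minimal consumer sketch for exercising the helper outside the gradio server; the message list, model name, and sampling values are illustrative only and not part of the patch.

import os

from fastchat.serve.api_provider import critique_api_stream_iter

# Illustrative OpenAI-style conversation, the same shape that
# conv.to_openai_api_messages() produces in get_api_provider_stream_iter().
messages = [
    {"role": "system", "content": "You are a helpful search assistant."},
    {"role": "user", "content": "Summarize today's top technology news."},
]

printed = ""
for chunk in critique_api_stream_iter(
    model_name="critique-agentic-search",
    messages=messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=1024,
    api_key=os.environ.get("CRITIQUE_API_KEY"),  # the helper also falls back to this env var
):
    if chunk["error_code"] != 0:
        print(chunk["text"])  # error text produced by the helper
        break
    # Each yield carries the accumulated text; print only the new suffix.
    print(chunk["text"][len(printed):], end="", flush=True)
    printed = chunk["text"]

One observation on the current behavior: temperature, top_p, and max_new_tokens are only logged in gen_params; the WebSocket request itself carries just the combined prompt.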