Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fastchat/model/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ def get_model_info(name: str) -> ModelInfo:
"Frontier Multimodal Language Model by Reka",
)

# Registry entry for Critique Labs AI: the three names below are aliases that
# all share the same display name, homepage link, and description.
register_model_info(
["critique-agentic-search", "critique-agentic-search-api", "critique-labs-ai"],
"Critique Labs AI",
"https://www.critique-labs.ai/",
"Agentic Search Engine By Critique Labs AI",
)

register_model_info(
["gemini-pro", "gemini-pro-dev-api"],
"Gemini",
Expand Down
171 changes: 171 additions & 0 deletions fastchat/serve/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
api_key=model_api_dict["api_key"],
extra_body=extra_body,
)
elif model_api_dict["api_type"] == "critique-labs-ai":
prompt = conv.to_openai_api_messages()
stream_iter = critique_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
max_new_tokens,
api_key=model_api_dict.get("api_key"),
api_base=model_api_dict.get("api_base"),
)
else:
raise NotImplementedError()

Expand Down Expand Up @@ -1345,3 +1356,163 @@ def metagen_api_stream_iter(
"text": f"**API REQUEST ERROR** Reason: Unknown.",
"error_code": 1,
}


def _critique_build_prompt(messages):
    """Flatten OpenAI-style chat messages into one prompt string."""
    pieces = []
    for message in messages:
        role = message["role"]
        # System messages are forwarded without a speaker label.
        role_prefix = "" if role == "system" else f"{role.capitalize()}: "
        content = message["content"]
        if isinstance(content, str):
            texts = [content]
        else:
            # List content (multimodal): only the text parts are forwarded.
            texts = [
                item["text"] for item in content if item.get("type") == "text"
            ]
        pieces.extend(f"{role_prefix}{text}\n" for text in texts)
    pieces.append("\n DO NOT RESPONSE IN MARKDOWN or provide any citations")
    return "".join(pieces)


def critique_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a Critique Labs search answer over a WebSocket.

    The chat history is flattened into a single prompt, sent as one JSON
    message, and the server's JSON replies are relayed from a background
    thread (which runs its own asyncio event loop) to this generator through
    a queue.

    Args:
        model_name: Model identifier; only logged, not sent to the server.
        messages: OpenAI-style message dicts with "role" and "content"
            (a string, or a list of parts of which the text parts are used).
        temperature: Only logged, not sent to the server.
        top_p: Only logged, not sent to the server.
        max_new_tokens: Only logged, not sent to the server.
        api_key: API key; falls back to the CRITIQUE_API_KEY environment
            variable when falsy.
        api_base: WebSocket URI; defaults to the public search endpoint.

    Yields:
        Dicts with "text" and "error_code". For each "response" message the
        accumulated text so far is yielded with error_code 0; on an "error"
        message, a connection failure, or a missing key, a single
        "**API REQUEST ERROR**" dict with error_code 1 is yielded.
        Messages of type "context" are received but not surfaced.
    """
    import asyncio
    import json
    import queue
    import threading

    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
    if not api_key:
        yield {
            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
            "error_code": 1,
        }
        return

    # Imported after the key check so the configuration error above can be
    # reported even when the optional dependency is not installed.
    import websockets

    prompt = _critique_build_prompt(messages)

    # Log request parameters
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"
    poll_seconds = 0.5  # how often each side re-checks its stop/closed flag

    response_queue = queue.Queue()
    stop_event = threading.Event()  # consumer -> worker: stop receiving
    connection_closed = threading.Event()  # worker -> consumer: no more data

    async def connect_and_stream():
        try:
            async with websockets.connect(
                uri, additional_headers={"X-API-Key": api_key}
            ) as websocket:
                await websocket.send(json.dumps({"prompt": prompt}))

                while not stop_event.is_set():
                    try:
                        # Bounded wait so stop_event is honoured even while
                        # the server is silent; otherwise an abandoned
                        # generator leaves this thread parked in recv().
                        response = await asyncio.wait_for(
                            websocket.recv(), timeout=poll_seconds
                        )
                    except asyncio.TimeoutError:
                        continue
                    except websockets.exceptions.ConnectionClosed:
                        # The server signals end-of-stream by closing.
                        logger.info(
                            "WebSocket connection closed by server - this is the expected end signal"
                        )
                        break

                    data = json.loads(response)
                    response_queue.put(data)

                    # If we get an error, we're done
                    if data["type"] == "error":
                        break
        except Exception as e:
            # Connection/protocol/parse failures are forwarded to the
            # consumer as an error message instead of dying silently.
            logger.error(f"WebSocket error: {str(e)}")
            response_queue.put(
                {"type": "error", "content": f"WebSocket error: {str(e)}"}
            )
        finally:
            # Set last, after every put(), so the consumer can drain the
            # queue before it stops.
            connection_closed.set()

    worker = threading.Thread(
        target=lambda: asyncio.run(connect_and_stream()), daemon=True
    )
    worker.start()

    try:
        text = ""

        # Keep reading until the worker is done AND the queue is drained.
        while not connection_closed.is_set() or not response_queue.empty():
            try:
                data = response_queue.get(timeout=poll_seconds)
            except queue.Empty:
                # Just a timeout to re-check whether the connection closed.
                continue

            message_type = data["type"]
            if message_type == "response":
                text += data["content"]
                yield {
                    "text": text,
                    "error_code": 0,
                }
            elif message_type == "error":
                logger.error(f"Critique API error: {data['content']}")
                yield {
                    "text": f"**API REQUEST ERROR** Reason: {data['content']}",
                    "error_code": 1,
                }
                break
            # Other message types (e.g. "context") carry nothing to yield.

    except Exception as e:
        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
            "error_code": 1,
        }
    finally:
        # Signal the worker to stop and give it a moment to close the socket.
        stop_event.set()
        worker.join(timeout=5)