diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..7d661183f 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -2,6 +2,7 @@
 import openai
 import logging
 import os
+import re
 from datetime import datetime
 import time
 import json
@@ -16,9 +17,63 @@
 import yaml
 from pathlib import Path
 from types import SimpleNamespace as config
+from typing import Any, Dict, Optional, Tuple
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
 
+# Shared clients, cached per API key so HTTP connection pools are reused
+# across calls instead of being rebuilt for every request.
+_sync_clients: Dict[str, openai.OpenAI] = {}
+# Each async entry also records the event loop it was created on: the client's
+# connections are tied to that loop and must not be reused on another one
+# (for example across separate asyncio.run() calls).
+_async_clients: Dict[str, Tuple[Any, openai.AsyncOpenAI]] = {}
+
+
+def _normalize_api_key(api_key: Optional[str]) -> str:
+    """Return the cache key for ``api_key``.
+
+    Falls back to ``CHATGPT_API_KEY`` and then to ``""`` so that callers
+    passing ``None`` share a single cache entry.
+    """
+    return api_key or CHATGPT_API_KEY or ""
+
+
+def _get_sync_client(api_key: Optional[str] = None) -> openai.OpenAI:
+    """Return the shared sync OpenAI client for ``api_key``.
+
+    The client is created on first use and cached per API key.
+    """
+    key = _normalize_api_key(api_key)
+    client = _sync_clients.get(key)
+    if client is None:
+        client = openai.OpenAI(api_key=api_key or CHATGPT_API_KEY)
+        _sync_clients[key] = client
+    return client
+
+
+def _get_async_client(api_key: Optional[str] = None) -> openai.AsyncOpenAI:
+    """Return the shared async OpenAI client for ``api_key``.
+
+    Clients are cached per API key and per event loop. When the running loop
+    differs from the one the cached client was created on, a fresh client is
+    created and replaces the cached one.
+    """
+    import asyncio  # local import: keeps this helper self-contained
+
+    key = _normalize_api_key(api_key)
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        # Called outside a running loop; cache the client under "no loop".
+        loop = None
+    cached = _async_clients.get(key)
+    if cached is None or cached[0] is not loop:
+        cached = (loop, openai.AsyncOpenAI(api_key=api_key or CHATGPT_API_KEY))
+        _async_clients[key] = cached
+    return cached[1]
+
+
 def count_tokens(text, model=None):
     if not text:
         return 0
@@ -28,12 +83,11 @@ def count_tokens(text, model=None):
 
 def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = _get_sync_client(api_key)
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = [*chat_history, {"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -54,18 +108,17 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 time.sleep(1)  # Wait for 1秒 before retrying
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
-                return "Error"
+                return "", "error"
 
 
 
 def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = _get_sync_client(api_key)
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = [*chat_history, {"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -89,15 +142,15 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    client = _get_async_client(api_key)
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+            return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
@@ -709,4 +762,4 @@ def load(self, user_opt=None) -> config:
 
         self._validate_keys(user_dict)
         merged = {**self._default_dict, **user_dict}
-        return config(**merged)
\ No newline at end of file
+        return config(**merged)