pageindex/utils.py: 55 changes (40 additions, 15 deletions)
@@ -2,6 +2,7 @@
 import openai
 import logging
 import os
+import re
 from datetime import datetime
 import time
 import json
@@ -16,9 +17,35 @@
 import yaml
 from pathlib import Path
 from types import SimpleNamespace as config
+from typing import Dict, Optional
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
 
+# Singleton clients for reuse across calls, cached per API key.
+_sync_clients: Dict[str, openai.OpenAI] = {}
+_async_clients: Dict[str, openai.AsyncOpenAI] = {}
+
+
+def _normalize_api_key(api_key: Optional[str]) -> str:
+    return api_key or CHATGPT_API_KEY or ""
+
+
+def _get_sync_client(api_key: Optional[str] = None) -> openai.OpenAI:
+    """Get or create the shared sync OpenAI client."""
+    key = _normalize_api_key(api_key)
+    if key not in _sync_clients:
+        _sync_clients[key] = openai.OpenAI(api_key=api_key or CHATGPT_API_KEY)
+    return _sync_clients[key]
+
+
+def _get_async_client(api_key: Optional[str] = None) -> openai.AsyncOpenAI:
+    """Get or create the shared async OpenAI client."""
+    key = _normalize_api_key(api_key)
+    if key not in _async_clients:
+        _async_clients[key] = openai.AsyncOpenAI(api_key=api_key or CHATGPT_API_KEY)
+    return _async_clients[key]
+
+
 def count_tokens(text, model=None):
     if not text:
         return 0
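The helpers added above cache one `openai.OpenAI` / `openai.AsyncOpenAI` instance per normalized API key, so repeated calls reuse a single client (and its HTTP connection pool) instead of constructing a new one on every request. A minimal sketch of the intended caching behavior, assuming `pageindex/utils.py` is importable as `pageindex.utils` and using placeholder keys:

```python
# Sketch only: exercises the per-key client cache added in this diff.
from pageindex import utils

a = utils._get_sync_client("sk-demo-key-1")  # first call creates and caches a client
b = utils._get_sync_client("sk-demo-key-1")  # same key -> same cached client object
c = utils._get_sync_client("sk-demo-key-2")  # different key -> separate cached client

assert a is b
assert a is not c
```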
@@ -28,12 +55,11 @@ def count_tokens(text, model=None):
 
 def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = _get_sync_client(api_key)
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = chat_history + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -54,18 +80,17 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 time.sleep(1)  # Wait for 1 second before retrying
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
-                return "Error"
+                return "", "error"
 
 
 
 def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = _get_sync_client(api_key)
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = chat_history + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
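In both sync functions the message building switches from an in-place `append` to list concatenation. The old code aliased the caller's list, so every call, and every retry inside the loop, appended the prompt to the caller's `chat_history` as a side effect; concatenation builds a fresh list and leaves the input untouched. A standalone illustration in plain Python, not part of the diff:

```python
# Old pattern: `messages` aliases the caller's list, so the append leaks out.
chat_history = [{"role": "system", "content": "You are helpful."}]
messages = chat_history
messages.append({"role": "user", "content": "a prompt"})
assert len(chat_history) == 2  # caller's history was mutated

# New pattern: concatenation creates a new list; the caller's history is unchanged.
chat_history = [{"role": "system", "content": "You are helpful."}]
messages = chat_history + [{"role": "user", "content": "a prompt"}]
assert len(chat_history) == 1
```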
@@ -89,15 +114,15 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    client = _get_async_client(api_key)
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+            return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
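In the async path the old code opened and closed an `AsyncOpenAI` client via `async with` inside every retry attempt; with the cached client, concurrent requests share one client instance. A hedged usage sketch, assuming `CHATGPT_API_KEY` is set in the environment and using `gpt-4o-mini` purely as an example model name:

```python
import asyncio

from pageindex import utils

async def main():
    prompts = ["Summarize page 1", "Summarize page 2", "Summarize page 3"]
    # All three requests go through the single cached AsyncOpenAI client.
    results = await asyncio.gather(
        *(utils.ChatGPT_API_async("gpt-4o-mini", p) for p in prompts)
    )
    for text in results:
        print(text)

asyncio.run(main())
```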
@@ -709,4 +734,4 @@ def load(self, user_opt=None) -> config:
 
         self._validate_keys(user_dict)
         merged = {**self._default_dict, **user_dict}
-        return config(**merged)
+        return config(**merged)